[Mlir-commits] [mlir] [mlir][vector] Add registry for tensor dialect (PR #108045)

Longsheng Mou llvmlistbot at llvm.org
Tue Sep 10 09:09:24 PDT 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/108045

>From bd4faffea911f3caaffca4049ba0fa7b5d333eb4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Tue, 10 Sep 2024 23:28:33 +0800
Subject: [PATCH] [mlir][vector] Add registry for tensor dialect

This patch registers the tensor dialect in ConvertVectorToLLVMPass, which
fixes a crash when vector.transfer_write writes to a dynamically shaped tensor.
---
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp           |  2 ++
 mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir  | 10 ++++++++++
 2 files changed, 12 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 842d239cf6a512..4623b9667998cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -45,6 +46,7 @@ struct ConvertVectorToLLVMPass
     registry.insert<LLVM::LLVMDialect>();
     registry.insert<arith::ArithDialect>();
     registry.insert<memref::MemRefDialect>();
+    registry.insert<tensor::TensorDialect>();
     if (armNeon)
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7ac49c5f02347e..bd14823cea50ab 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2521,6 +2521,16 @@ func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<
 
 // -----
 
+// CHECK-LABEL: func @transfer_write_tensor
+//       CHECK:   vector.transfer_write
+func.func @transfer_write_tensor(%arg0: vector<4xf32>,%arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_write %arg0, %arg1[%c0] : vector<4xf32>, tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 func.func @genbool_0d_f() -> vector<i1> {
   %0 = vector.constant_mask [0] : vector<i1>
   return %0 : vector<i1>



More information about the Mlir-commits mailing list