[Mlir-commits] [mlir] cf21667 - [mlir][linalg] Add vectorization for linalg on tensor ops

Thomas Raoux llvmlistbot at llvm.org
Tue Dec 29 09:03:04 PST 2020


Author: Thomas Raoux
Date: 2020-12-29T09:02:23-08:00
New Revision: cf216670a0bd1f2ce561a315e00649740f117e1c

URL: https://github.com/llvm/llvm-project/commit/cf216670a0bd1f2ce561a315e00649740f117e1c
DIFF: https://github.com/llvm/llvm-project/commit/cf216670a0bd1f2ce561a315e00649740f117e1c.diff

LOG: [mlir][linalg] Add vectorization for linalg on tensor ops

Support vectorization of linalg ops using tensor inputs/outputs.

Differential Revision: https://reviews.llvm.org/D93890

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23e452df9184..2a1d4cd2ef57 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -119,34 +119,38 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
-static VectorType extractVectorTypeFromScalarView(Value v) {
-  MemRefType mt = v.getType().cast<MemRefType>();
-  return mt.getShape().empty()
-             ? VectorType()
-             : VectorType::get(mt.getShape(), mt.getElementType());
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+  auto st = v.getType().cast<ShapedType>();
+  if (st.isa<MemRefType>() && st.getShape().empty())
+    return VectorType();
+  return VectorType::get(st.getShape(), st.getElementType());
 }
 
-static Value transferReadVector(OpBuilder &builder, Value memref) {
+static Value transferReadVector(OpBuilder &builder, Value source) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
-    return vector_transfer_read(vectorType, memref, indices);
+  auto shapedType = source.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+    return vector_transfer_read(vectorType, source, indices);
   }
-  return std_load(memref);
+  return std_load(source);
 }
 
-static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+  Operation *write;
+  auto shapedType = dest.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
     if (vectorType != value.getType())
       value = vector_broadcast(vectorType, value);
-    vector_transfer_write(value, memref, indices);
+    write = vector_transfer_write(value, dest, indices);
   } else {
-    std_store(value, memref);
+    write = std_store(value, dest);
   }
+  if (!write->getResults().empty())
+    return write->getResult(0);
+  return Value();
 }
 
 namespace {
@@ -167,10 +171,12 @@ class GenericVectorizer {
   void vectorize(Operation &scalarOp) {
     auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
     if (yieldOp) {
-      for (auto outputAndMemref :
-           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
-        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
-        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+      for (auto outputs : llvm::enumerate(yieldOp.values())) {
+        Value vectorValue = vectorize(outputs.value());
+        Value result = transferWriteVector(builder, vectorValue,
+                                           generic.getOutput(outputs.index()));
+        if (result)
+          results.push_back(result);
       }
       return;
     }
@@ -182,6 +188,8 @@ class GenericVectorizer {
     }
   }
 
+  llvm::ArrayRef<Value> getResults() { return results; }
+
 private:
   // Transforms a scalar value into its vectorized counterpart, recursively
   // vectorizing operations as necessary using the underlying builder.
@@ -261,6 +269,7 @@ class GenericVectorizer {
   OpBuilder &builder;
   linalg::GenericOp generic;
   llvm::DenseMap<Value, Value> valueCache;
+  SmallVector<Value, 8> results;
 };
 } // namespace
 
@@ -271,6 +280,8 @@ static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
   for (Operation &scalarOp : op.region().front()) {
     vectorizer.vectorize(scalarOp);
   }
+  if (!op->getResults().empty())
+    op->replaceAllUsesWith(vectorizer.getResults());
 }
 
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -331,32 +342,14 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value viewA = linalgOp.getInput(0);
-  Value viewB = linalgOp.getInput(1);
-  Value viewC = linalgOp.getOutputBuffer(0);
-  VectorType vtA = extractVectorTypeFromScalarView(viewA);
-  VectorType vtB = extractVectorTypeFromScalarView(viewB);
-  VectorType vtC = extractVectorTypeFromScalarView(viewC);
-  Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA, indicesB, indicesC;
-  if (vtA)
-    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
-  if (vtB)
-    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
-  if (vtC)
-    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
-  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
-                : std_load(viewA, indicesA).value;
-  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
-                : std_load(viewB, indicesB).value;
-  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
-                : std_load(viewC, indicesC).value;
+  Value a = transferReadVector(builder, linalgOp.getInput(0));
+  Value b = transferReadVector(builder, linalgOp.getInput(1));
+  Value c = transferReadVector(builder, linalgOp.getOutput(0));
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  if (vtC)
-    vector_transfer_write(res, viewC, indicesC);
-  else
-    std_store(res, viewC, indicesC);
+  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
+  if (writeResult)
+    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a6a353c97977..5acc5f5a7ce2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2039,28 +2039,28 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
 
 /// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value memref, ValueRange indices,
+                           VectorType vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
-  Type elemType = memref.getType().cast<MemRefType>().getElementType();
+  Type elemType = source.getType().cast<ShapedType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
-    return build(builder, result, vector, memref, indices, permutationMap,
+    return build(builder, result, vector, source, indices, permutationMap,
                  padding, ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, memref, indices, permutationMap, padding,
+  build(builder, result, vector, source, indices, permutationMap, padding,
         maskedArrayAttr);
 }
 
 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
 /// (resp. zero).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value memref,
+                           VectorType vectorType, Value source,
                            ValueRange indices, ArrayRef<bool> maybeMasked) {
   auto permMap = getTransferMinorIdentityMap(
-      memref.getType().cast<MemRefType>(), vectorType);
-  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
+      source.getType().cast<ShapedType>(), vectorType);
+  build(builder, result, vectorType, source, indices, permMap, maybeMasked);
 }
 
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
@@ -2251,7 +2251,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<MemRefType>(), vectorType);
+      source.getType().cast<ShapedType>(), vectorType);
   if (maybeMasked.empty())
     return build(builder, result, vector, source, indices, permMap,
                  ArrayAttr());
@@ -2327,7 +2327,7 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
 }
 
 static LogicalResult verify(TransferWriteOp op) {
-  // Consistency of elemental types in memref and vector.
+  // Consistency of elemental types in shape and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto permutationMap = op.permutation_map();

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 6019dde49983..cf49e85a4a28 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -189,7 +189,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 //       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
 //       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
-//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> 
+//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 //       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
 //       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
 //       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -209,3 +209,108 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 //       CHECK:   vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 //       CHECK:   vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 //       CHECK:   vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+
+func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
+  %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
+  %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
+  %c1_f32 = constant 1.0 : f32
+  %r:10 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+  ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
+  outs(
+    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+    %arg14 : f32):
+    %6 = addf %arg4, %arg6 : f32
+    %7 = cmpf "ogt", %arg3, %arg6 : f32
+    %8 = constant 2.0 : f32
+    %9 = divf %arg5, %i : f32
+    %10 = exp2 %arg5 : f32
+    %11 = mulf %arg5, %8 : f32
+    %12 = rsqrt %arg5 : f32
+    %13 = select %7, %arg5, %arg6 : f32
+    %14 = subf %arg5, %arg4 : f32
+    %15 = tanh %arg5 : f32
+    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+      f32, f32, f32, f32, f32, f32, f32, f32
+  } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+  return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+}
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
+//   CHECK-DAG:   %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+//   CHECK-DAG:   %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+//       CHECK:   %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+//       CHECK:   %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+//       CHECK:   %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+//       CHECK:   %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+//       CHECK:   %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+//       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+//       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+//       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+//       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+//       CHECK:   %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+//       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+
+func @matmul_tensors(
+  %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
+                     outs(%arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32>
+  return %0 : tensor<8x12xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+//  CHECK-SAME:  %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+//       CHECK:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+//       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+//       CHECK:   %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+//       CHECK:   return %[[W]] : tensor<8x12xf32>


        


More information about the Mlir-commits mailing list