[Mlir-commits] [mlir] [mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support (PR #97297)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 1 07:16:44 PDT 2024


https://github.com/BRUCE11111 created https://github.com/llvm/llvm-project/pull/97297

None

From 49e12f7409f366c135dc496beccac2d3f9f35362 Mon Sep 17 00:00:00 2001
From: "Xu, Xiaohui1" <xiaohui1.xu at intel.com>
Date: Mon, 1 Jul 2024 22:12:16 +0800
Subject: [PATCH] [mlir][vector] add tensor.concat, bitcast, expand_shape,
 collapse_shape vectorization support

---
 .../TransformOps/LinalgTransformOps.cpp       |   5 +-
 .../Linalg/Transforms/Vectorization.cpp       | 334 ++++++++++++++++++
 mlir/test/Dialect/Linalg/vectorization.mlir   | 192 ++++++++++
 3 files changed, 529 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..e0fd5f1b14070 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-            target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::BitcastOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
+             tensor::ConcatOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..7a4db82749fd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   return success();
 }
 
+/// Vectorize a `tensor.expand_shape` or `tensor.collapse_shape` into these
+/// 3 ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::ShapeCastOp - Reshapes the data based on the target shape.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
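+///
+/// For example (mirroring the unmasked expand_shape test added in this
+/// patch; the SSA names are illustrative only), the op
+///   %e = tensor.expand_shape %src [[0, 1], [2, 3]]
+///       output_shape [8, 8, 32, 16]
+///       : tensor<64x512xf32> into tensor<8x8x32x16xf32>
+/// is rewritten roughly as
+///   %read = vector.transfer_read %src[%c0, %c0], %cst
+///       : tensor<64x512xf32>, vector<64x512xf32>
+///   %cast = vector.shape_cast %read
+///       : vector<64x512xf32> to vector<8x8x32x16xf32>
+///   %write = vector.transfer_write %cast, %empty[%c0, %c0, %c0, %c0]
+///       : vector<8x8x32x16xf32>, tensor<8x8x32x16xf32>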
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+                                        Operation *inputOp,
+                                        ArrayRef<int64_t> inputVectorSizes,
+                                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(inputOp);
+  auto src = inputOp->getOperand(0);
+  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+  auto result = inputOp->getResults()[0];
+  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+  ArrayRef<int64_t> resultShape = resultType.getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Location loc = inputOp->getLoc();
+
+  llvm::SmallVector<int64_t> srcVectorizedShape;
+  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+                               ArrayRef<int64_t> &inputShape) {
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+    int64_t cur = 1, resultIdx = 0;
+    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+      cur *= ss;
+      if (!isResultShapeBigger) {
+        // collapse
+        srcVectorizedShape.emplace_back(ss);
+        if (cur == retShape[resultIdx]) {
+          if (shapeScales.count(resultIdx)) {
+            srcVectorizedShape.back() *= shapeScales[resultIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      } else {
+        // expand
+        if (cur == retShape[resultIdx]) {
+          srcVectorizedShape.emplace_back(cur);
+          if (shapeScales.count(srcIdx)) {
+            srcVectorizedShape.back() *= shapeScales[srcIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      }
+    }
+  };
+  if (!inputVectorSizes.empty()) {
+    for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+      if (vs != resultShape[idx])
+        shapeScales[idx] = vs / resultShape[idx];
+    }
+
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+    if (!isResultShapeBigger) {
+      getVectorizeShape(resultShape, srcShape);
+    } else {
+      getVectorizeShape(srcShape, resultShape);
+    }
+  } else {
+    srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+  }
+  // read
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, src,
+      inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+      padValue, false);
+
+  auto shapeCastType =
+      VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+                      resultType.getElementType());
+  vector::ShapeCastOp shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+  // write
+  SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape) {
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  }
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp->getResults()[0], destSizes,
+      inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
+/// Vectorize a `tensor.bitcast` into these 3 ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::BitCastOp - Bitcasts the data to the result element type.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
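+///
+/// For example (mirroring the unmasked bitcast test added in this patch; the
+/// SSA names are illustrative only), the op
+///   %0 = tensor.bitcast %src : tensor<64x512xi32> to tensor<64x512xf32>
+/// is rewritten roughly as
+///   %read = vector.transfer_read %src[%c0, %c0], %c0_i32
+///       : tensor<64x512xi32>, vector<64x512xi32>
+///   %cast = vector.bitcast %read
+///       : vector<64x512xi32> to vector<64x512xf32>
+///   %write = vector.transfer_write %cast, %empty[%c0, %c0]
+///       : vector<64x512xf32>, tensor<64x512xf32>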
+static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
+                                          tensor::BitcastOp bitCastOp,
+                                          ArrayRef<int64_t> inputVectorSizes,
+                                          SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(bitCastOp);
+
+  auto sourceType = bitCastOp.getSource().getType();
+  auto resultType = bitCastOp.getResult().getType();
+  auto resultShape = resultType.getShape();
+  if (inputVectorSizes.empty()) {
+    inputVectorSizes = resultShape;
+  }
+  Location loc = bitCastOp->getLoc();
+
+  // read
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(sourceType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);
+
+  // bitcast
+  auto resultVectorType =
+      VectorType::get(inputVectorSizes, resultType.getElementType());
+  vector::BitCastOp vectorBitCastOp =
+      rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);
+
+  // write
+  llvm::SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape)
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  auto write =
+      createWriteOrMaskedWrite(rewriter, loc, vectorBitCastOp->getResult(0),
+                               destSizes, inputVectorSizes, false);
+  newResults.push_back(write->getResults()[0]);
+  return success();
+}
+
+/// Vectorize a `tensor.concat` into:
+///   tensor::EmptyOp - The destination tensor.
+///   One vector::TransferReadOp and one vector::TransferWriteOp per input,
+///   writing each input into the destination at its running offset along the
+///   concatenation dimension.
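+///
+/// For example (mirroring the concat test added in this patch; the SSA names
+/// are illustrative only), the op
+///   %0 = tensor.concat dim(1) %a, %b
+///       : (tensor<64x512xf32>, tensor<64x512xf32>) -> tensor<64x1024xf32>
+/// is rewritten roughly as
+///   %empty = tensor.empty() : tensor<64x1024xf32>
+///   %r0 = vector.transfer_read %a[%c0, %c0], %cst
+///       : tensor<64x512xf32>, vector<64x512xf32>
+///   %w0 = vector.transfer_write %r0, %empty[%c0, %c0]
+///       : vector<64x512xf32>, tensor<64x1024xf32>
+///   %d = tensor.dim %a, %c1 : tensor<64x512xf32>
+///   %r1 = vector.transfer_read %b[%c0, %c0], %cst
+///       : tensor<64x512xf32>, vector<64x512xf32>
+///   %w1 = vector.transfer_write %r1, %w0[%c0, %d]
+///       : vector<64x512xf32>, tensor<64x1024xf32>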
+static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
+                                         tensor::ConcatOp concatOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(concatOp);
+
+  Location loc = concatOp.getLoc();
+  FailureOr<Value> dest =
+      tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+  if (failed(dest))
+    return failure();
+
+  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+  if (!empty)
+    return failure();
+
+  // Materialize the concatenation dimension as an index constant; the write
+  // offset along this dimension is accumulated across the inputs below.
+  auto dim = concatOp.getDim();
+  Value dimValue =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+
+  int64_t rank = concatOp.getResultType().getRank();
+  auto srcType =
+      mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+
+  // Construct the chain of vector.transfer_write ops into the destination.
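+  // For the two-input case in the test added in this patch (dim = 1, two
+  // tensor<64x512xf32> inputs), the first input is written at offset [0, 0]
+  // and the second at offset [0, %dim], where %dim is the size of the first
+  // input along the concatenation dimension.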
+  Value result = *dest;
+  Value previousOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+    auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
+    auto sourceShape = inputType.getShape();
+
+    Value readResult = vector::createReadOrMaskedRead(
+        rewriter, loc, input, sourceShape, padValue, false);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices(rank, zero);
+    indices[dim] = previousOffset;
+    result = rewriter
+                 .create<vector::TransferWriteOp>(
+                     loc, readResult, result, indices,
+                     rewriter.getMultiDimIdentityMap(rank))
+                 ->getResults()[0];
+    if (idx != concatOp.getNumOperands() - 1) {
+      auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
+      previousOffset =
+          rewriter.create<arith::AddIOp>(loc, dimOp, previousOffset);
+    }
+  }
+
+  newResults.push_back(result);
+  return success();
+}
+
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
+static LogicalResult
+lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = expandOp->getResultTypes()[0];
+  auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
+  // check reassociation
+  llvm::SmallVector<int64_t> associateIndices;
+  for (auto &attr : expandOp.getReassociation()) {
+    for (auto &index : mlir::dyn_cast<ArrayAttr>(attr)) {
+      associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(index).getInt());
+    }
+  }
+
+  if (llvm::any_of(associateIndices,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    LDBG("Reassociation must be static: " << expandOp << "\n");
+    return failure();
+  }
+  // check input and output shape
+  if (!resultShape.hasStaticShape() ||
+      !expandOp.getSrcType().hasStaticShape()) {
+    LDBG("Input and output shape must be static: " << expandOp << "\n");
+    return failure();
+  }
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShape.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult
+lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = bitCastOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcType = bitCastOp.getSource().getType();
+  auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << bitCastOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
+                                 ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = collapseOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcShapeType = collapseOp.getSrcType();
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << collapseOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  if (!inputVectorSizes.empty()) {
+    LDBG("Concat operation do not support specify inputVectorSizes: "
+         << concatOp << "\n");
+  }
+  for (auto x : concatOp->getOperands()) {
+    auto type = mlir::dyn_cast<ShapedType>(x.getType());
+    if (!type) {
+      LDBG("Operation type error: " << concatOp << "\n");
+      return failure();
+    }
+    if (!type.hasStaticShape()) {
+      LDBG("Type must be static: " << concatOp << "\n");
+      return failure();
+    }
+  }
+  auto dim = concatOp.getDim();
+  if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
+    LDBG("Invalid dim: " << concatOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 /// Preconditions for scalable vectors.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+        return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
+      })
+      .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+        return lowerCollapseShapeOpPrecondition(collapseShapeOp,
+                                                inputVectorSizes);
+      })
+      .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+        return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
+      })
+      .Case<tensor::ConcatOp>([&](auto concatOp) {
+        return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
           })
+          .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+            return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
+                                      results);
+          })
+          .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+            return lowerTensorReshape(rewriter, collapseShapeOp,
+                                      inputVectorSizes, results);
+          })
+          .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+            return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
+                                        results);
+          })
+          .Case<tensor::ConcatOp>([&](auto concatOp) {
+            return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..114815b4e3de8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,195 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   } 
  }
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape
+func.func @test_vectorize_collapseshape(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C80:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x32xf32> to vector<64x1024xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[C512:.*]] = arith.constant 512 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+   return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape_no_vector_size
+func.func @test_vectorize_collapseshape_no_vector_size(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x16xf32> to vector<64x512xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+   return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_expandshape
+func.func @test_vectorize_expandshape(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[C512:.*]] = arith.constant 512 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x1024xf32> to vector<8x8x32x32xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+    // CHECK: %[[C01:.*]]= arith.constant 0 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C80:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<8x8x32x32xi1> -> tensor<8x8x32x16xf32>
+    // CHECK: return %[[WRIT]] : tensor<8x8x32x16xf32>
+  %expanded = tensor.expand_shape %source [[0, 1], [2, 3]] output_shape [8, 8, 32, 16] : tensor<64x512xf32> into tensor<8x8x32x16xf32>
+   return %expanded : tensor<8x8x32x16xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.expand_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [8, 8, 32, 32] : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_expandshape_no_vector_size
+func.func @test_vectorize_expandshape_no_vector_size(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x512xf32> to vector<8x8x32x16xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() :  tensor<8x8x32x16xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true, true, true]} : vector<8x8x32x16xf32>, tensor<8x8x32x16xf32>
+    // CHECK: return %[[WRIT]] : tensor<8x8x32x16xf32>
+    %expanded = tensor.expand_shape %source [[0, 1], [2, 3]] output_shape [8, 8, 32, 16] : tensor<64x512xf32> into tensor<8x8x32x16xf32>
+   return %expanded : tensor<8x8x32x16xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.expand_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_bitcast
+func.func @test_vectorize_bitcast(%source: tensor<64x512xi32>) -> tensor<64x512xf32> {
+    // CHECK: %[[C0i32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[C512:.*]] = arith.constant 512 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xi32>
+    // CHECK: %[[CAST:.*]] = vector.bitcast %[[READ0]] : vector<64x1024xi32> to vector<64x1024xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+    // CHECK: %[[C00:.*]]= arith.constant 0 : index
+    // CHECK: %[[C641:.*]] = arith.constant 64 : index
+    // CHECK: %[[C5121:.*]] = arith.constant 512 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C641]], %[[C5121]] : vector<64x1024xi1>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %0 = tensor.bitcast %source : tensor<64x512xi32> to tensor<64x512xf32>
+  return %0 : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.bitcast"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_bitcast_no_vector_size
+func.func @test_vectorize_bitcast_no_vector_size(%source: tensor<64x512xi32>) -> tensor<64x512xf32> {
+    // CHECK: %[[C0i32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[C0i32]] {in_bounds = [true, true]} : tensor<64x512xi32>, vector<64x512xi32>
+    // CHECK: %[[CAST:.*]] = vector.bitcast %[[READ0]] : vector<64x512xi32> to vector<64x512xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() :  tensor<64x512xf32>
+    // CHECK: %[[C00:.*]] = arith.constant 0 : index
+    // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %0 = tensor.bitcast %source : tensor<64x512xi32> to tensor<64x512xf32>
+  return %0 : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.bitcast"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+}
+
+// -----
+
+ // CHECK-LABEL: func @test_vectorize_concat_no_vector_size
+func.func @test_vectorize_concat_no_vector_size(%arg0: tensor<64x512xf32>, %arg1:tensor<64x512xf32>) -> tensor<64x1024xf32> {
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x1024xf32>
+    // CHECK: %[[C1:.*]]= arith.constant 1 : index
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C0_0:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+    // CHECK: %[[C0_1:.*]]= arith.constant 0 : index
+    // CHECK: %[[WRIT0:.*]] = vector.transfer_write %[[READ0]], %[[EMPT]][{{.*}}, {{.*}}] : vector<64x512xf32>, tensor<64x1024xf32>
+    // CHECK: %[[DIM:.*]] = tensor.dim {{.*}}, {{.*}} : tensor<64x512xf32>
+    // CHECK: %[[ADD:.*]] = arith.addi %[[DIM]], {{.*}} : index
+    // CHECK: %[[C0_2:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+    // CHECK: %[[C0_3:.*]]= arith.constant 0 : index
+    // CHECK: %[[WRIT1:.*]] = vector.transfer_write %[[READ1]], %[[WRIT0]][{{.*}}, {{.*}}] : vector<64x512xf32>, tensor<64x1024xf32>
+    // CHECK: return %[[WRIT1]] : tensor<64x1024xf32>
+  %0 = tensor.concat dim(1) %arg0, %arg1
+             : (tensor<64x512xf32>, tensor<64x512xf32>) -> tensor<64x1024xf32>
+  return %0 : tensor<64x1024xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.concat"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+}
+



More information about the Mlir-commits mailing list