[Mlir-commits] [mlir] edd9515 - [mlir][VectorToGPU] First step to convert vector ops to GPU MMA ops

llvmlistbot at llvm.org
Fri Jun 11 07:53:37 PDT 2021


Author: thomasraoux
Date: 2021-06-11T07:52:32-07:00
New Revision: edd9515bd125634f40ebc2e783d6a127345e7c0d

URL: https://github.com/llvm/llvm-project/commit/edd9515bd125634f40ebc2e783d6a127345e7c0d
DIFF: https://github.com/llvm/llvm-project/commit/edd9515bd125634f40ebc2e783d6a127345e7c0d.diff

LOG: [mlir][VectorToGPU] First step to convert vector ops to GPU MMA ops

This is the first step to convert vector ops to MMA operations in order to
target GPU tensor core ops. This currently only supports simple cases;
transpose and element-wise operations will be added later.

Differential Revision: https://reviews.llvm.org/D102962
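
As a rough usage sketch (not part of this patch; the helper name and surrounding
boilerplate are illustrative only), the new pass can be scheduled from C++
alongside canonicalization, mirroring the RUN line of the test added below:

    #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Support/LogicalResult.h"
    #include "mlir/Transforms/Passes.h"

    // Run the new vector-to-GPU conversion on every function in `module`,
    // then canonicalize, i.e. the programmatic equivalent of
    // `mlir-opt -convert-vector-to-gpu -canonicalize`.
    static mlir::LogicalResult runVectorToGPU(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addNestedPass<mlir::FuncOp>(mlir::createConvertVectorToGPUPass());
      pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
      return pm.run(module);
    }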

Added: 
    mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
    mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a78b72894c49c..0cac29f2b62f7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -38,6 +38,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
 #include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ba5e27a2a87c6..47f328b6c1fb1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -513,6 +513,20 @@ def TosaToStandard : Pass<"tosa-to-standard"> {
   let constructor = "tosa::createTosaToStandard()";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToGPU : FunctionPass<"convert-vector-to-gpu"> {
+  let summary = "Lower the operations from the vector dialect into the GPU "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToGPUPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "gpu::GPUDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
new file mode 100644
index 0000000000000..5f6f7aa30ea39
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -0,0 +1,34 @@
+//===- VectorToGPU.h - Convert vector to GPU dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
+#define MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class MLIRContext;
+class Pass;
+class FuncOp;
+class RewritePatternSet;
+
+/// Patterns to transform vector ops into a canonical form to convert to MMA
+/// matrix operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+
+/// Convert vector ops to MMA matrix operations. This will convert the slices
+/// of operations that can be legally converted to MMA operations. The rest of
+/// the vector operations are left untouched.
+void convertVectorToMMAOps(FuncOp funcOp);
+
+/// Convert from vector to GPU ops.
+std::unique_ptr<Pass> createConvertVectorToGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
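
For reference, a minimal sketch of how the entry points declared in this header
compose (it mirrors what ConvertVectorToGPUPass does in VectorToGPU.cpp further
down; the free function here is hypothetical):

    #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // First rewrite vector ops into the canonical (m, k) x (k, n) -> (m, n)
    // matmul form, then convert the legal slices to gpu.subgroup_mma_* ops;
    // everything else is left untouched.
    static void lowerVectorToMMA(mlir::FuncOp funcOp) {
      mlir::RewritePatternSet patterns(funcOp.getContext());
      mlir::populatePrepareVectorToMMAPatterns(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
      mlir::convertVectorToMMAOps(funcOp);
    }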

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 72cfb08405ace..66b3895244f73 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -29,5 +29,6 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToStandard)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
+add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)

diff  --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..484ad5451fc00
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRVectorToGPU
+  VectorToGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRGPU
+  MLIRLLVMIR
+  MLIRMemRef
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
new file mode 100644
index 0000000000000..227890bc1f661
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -0,0 +1,338 @@
+//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to GPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+
+#include "../PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+// Return true if the contract op can be converted to an MMA matmul.
+static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
+  if (llvm::size(contract.masks()) != 0)
+    return false;
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(contract.getContext(), m, n, k);
+  auto iteratorTypes = contract.iterator_types().getValue();
+  if (!(isParallelIterator(iteratorTypes[0]) &&
+        isParallelIterator(iteratorTypes[1]) &&
+        isReductionIterator(iteratorTypes[2])))
+    return false;
+
+  // The contract needs to represent a matmul to be convertible to an
+  // MMAMatrix matmul.
+  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+    return false;
+
+  // Check that the size matches what is natively supported.
+  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
+  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
+  VectorType accType = contract.acc().getType().cast<VectorType>();
+
+  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
+                                lhsType.getDimSize(1));
+  if (lhsType.getElementType().isInteger(8) &&
+      rhsType.getElementType().isInteger(8) &&
+      accType.getElementType().isInteger(32) &&
+      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
+       dim == std::make_tuple(16, 8, 32)))
+    return true;
+
+  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
+      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
+      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
+       dim == std::make_tuple(16, 8, 16)))
+    return true;
+  return false;
+}
+
+// Return the stride of dimension 0 of |type| if it is a memref and has a
+// constant stride.
+static llvm::Optional<int64_t>
+getMemrefConstantHorizontalStride(ShapedType type) {
+  auto memrefType = type.dyn_cast<MemRefType>();
+  if (!memrefType)
+    return llvm::None;
+  int64_t offset = 0;
+  SmallVector<int64_t, 2> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return llvm::None;
+  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
+    return llvm::None;
+  return strides[0];
+}
+
+// Return true if the transfer op can be converted to an MMA matrix load.
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
+  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
+      readOp.getVectorType().getRank() != 2)
+    return false;
+  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+    return false;
+  // TODO: Support transpose once it is added to GPU dialect ops.
+  if (!readOp.permutation_map().isMinorIdentity())
+    return false;
+  return true;
+}
+
+// Return true if the transfer op can be converted to an MMA matrix store.
+static bool
+transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
+  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
+      writeOp.getVectorType().getRank() != 2)
+    return false;
+  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
+    return false;
+  // TODO: Support transpose once it is added to GPU dialect ops.
+  if (!writeOp.permutation_map().isMinorIdentity())
+    return false;
+  return true;
+}
+
+static bool supportsMMaMatrixType(Operation *op) {
+  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadSupportsMMAMatrixType(transferRead);
+  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteSupportsMMAMatrixType(transferWrite);
+  if (auto contract = dyn_cast<vector::ContractionOp>(op))
+    return contractSupportsMMAMatrixType(contract);
+  return false;
+}
+
+// Analyze the slice of operations anchored on each contract op to figure out
+// whether the whole slice can be converted to MMA operations.
+static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
+  auto hasVectorDest = [](Operation *op) {
+    return op->getNumResults() == 0 ||
+           llvm::any_of(op->getResultTypes(),
+                        [](Type t) { return t.isa<VectorType>(); });
+  };
+  SetVector<Operation *> opToConvert;
+  op->walk([&](vector::ContractionOp contract) {
+    if (opToConvert.contains(contract.getOperation()))
+      return;
+    SetVector<Operation *> dependentOps =
+        getSlice(contract, hasVectorDest, hasVectorDest);
+    // If any instruction cannot use the MMA matrix type, drop the whole
+    // chain. MMA matrices are stored in an opaque type so they cannot be used
+    // by all operations.
+    if (llvm::any_of(dependentOps,
+                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
+      return;
+    opToConvert.insert(dependentOps.begin(), dependentOps.end());
+  });
+  return opToConvert;
+}
+
+namespace {
+// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
+// to MMA matmul.
+struct PrepareContractToGPUMMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+    // Set up the parallel/reduction structure in the right form.
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m, n, k;
+    bindDims(rewriter.getContext(), m, n, k);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto iteratorTypes = op.iterator_types().getValue();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+    if (!(isParallelIterator(iteratorTypes[0]) &&
+          isParallelIterator(iteratorTypes[1]) &&
+          isReductionIterator(iteratorTypes[2])))
+      return failure();
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // This is the classical row-major matmul, nothing to do.
+      return failure();
+    }
+    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, lhs, rhs, res,
+        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
+        op.iterator_types());
+    return success();
+  }
+};
+
+// Merge the transpose op into the transfer read op. Transposes are not
+// supported on MMA types, but an MMA load can transpose the matrix when
+// loading.
+struct CombineTransferReadOpTranspose final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
+    if (!transferReadOp)
+      return failure();
+    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    SmallVector<int64_t, 2> perm;
+    op.getTransp(perm);
+    SmallVector<unsigned, 2> permU;
+    for (int64_t o : perm)
+      permU.push_back(unsigned(o));
+    AffineMap permutationMap =
+        AffineMap::getPermutationMap(permU, op.getContext());
+    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
+        newMap, transferReadOp.padding(), transferReadOp.mask(),
+        transferReadOp.in_boundsAttr());
+    return success();
+  }
+};
+
+} // namespace
+
+// MMA types have different layouts based on how they are used in matmul ops.
+// Figure out the right layout to use by looking at transfer op uses.
+// TODO: Change the GPU dialect to abstract the layout at this level and
+// only care about it during lowering to NVVM.
+static const char *inferFragType(vector::TransferReadOp op) {
+  for (Operation *users : op->getUsers()) {
+    auto contract = dyn_cast<vector::ContractionOp>(users);
+    if (!contract)
+      continue;
+    if (contract.lhs() == op.getResult())
+      return "AOp";
+    if (contract.rhs() == op.getResult())
+      return "BOp";
+  }
+  return "COp";
+}
+
+static void convertTransferReadOp(vector::TransferReadOp op,
+                                  llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(transferReadSupportsMMAMatrixType(op));
+  Optional<int64_t> stride =
+      getMemrefConstantHorizontalStride(op.getShapedType());
+  assert(stride);
+  const char *fragType = inferFragType(op);
+  gpu::MMAMatrixType type =
+      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
+                              op.getVectorType().getElementType(), fragType);
+  OpBuilder b(op);
+  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
+  valueMapping[op.getResult()] = load;
+}
+
+static void convertTransferWriteOp(vector::TransferWriteOp op,
+                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(transferWriteSupportsMMAMatrixType(op));
+  Optional<int64_t> stride =
+      getMemrefConstantHorizontalStride(op.getShapedType());
+  assert(stride);
+  OpBuilder b(op);
+  Value matrix = valueMapping.find(op.vector())->second;
+  b.create<gpu::SubgroupMmaStoreMatrixOp>(
+      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
+  op.erase();
+}
+
+static void convertContractOp(vector::ContractionOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Value opA = valueMapping.find(op.lhs())->second;
+  Value opB = valueMapping.find(op.rhs())->second;
+  Value opC = valueMapping.find(op.acc())->second;
+  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
+                                                     opA, opB, opC);
+  valueMapping[op.getResult()] = matmul;
+}
+
+namespace mlir {
+
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
+  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
+      patterns.getContext());
+}
+
+void convertVectorToMMAOps(FuncOp funcOp) {
+  SetVector<Operation *> ops = getOpToConvert(funcOp);
+  llvm::DenseMap<Value, Value> valueMapping;
+  for (Operation *op : ops) {
+    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
+      convertTransferReadOp(transferRead, valueMapping);
+    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
+      convertTransferWriteOp(transferWrite, valueMapping);
+    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
+      convertContractOp(contractOp, valueMapping);
+    }
+  }
+}
+
+} // namespace mlir
+namespace {
+
+struct ConvertVectorToGPUPass
+    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
+  void runOnFunction() override {
+    RewritePatternSet patterns(getFunction().getContext());
+    populatePrepareVectorToMMAPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+
+    convertVectorToMMAOps(getFunction());
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
+  return std::make_unique<ConvertVectorToGPUPass>();
+}

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
new file mode 100644
index 0000000000000..5005cc6c6b228
--- /dev/null
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -convert-vector-to-gpu -canonicalize | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
+// Negative test until scf.for support is added.
+// CHECK-LABEL: func @matmul_loop
+//       CHECK:   vector.contract
+func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
+  %c0 = constant 0 : index
+  %c128 = constant 128 : index
+  %c32 = constant 32 : index
+  %cst = constant 0.000000e+00 : f16
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+  %14 = scf.for %arg17 = %c0 to %c128 step %c32 iter_args(%arg18 = %C) -> (vector<16x16xf16>) {
+    %17 = vector.transfer_read %arg0[%c0, %arg17], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+    %18 = vector.transfer_read %arg1[%arg17, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+    %19 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %17, %18, %arg18 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    scf.yield %19 : vector<16x16xf16>
+  }
+  vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
+  return
+}

More information about the Mlir-commits mailing list