[Mlir-commits] [mlir] 59058c4 - [mlir][vector] Add operations used for Vector distribution

Thomas Raoux llvmlistbot at llvm.org
Thu Apr 14 20:53:57 PDT 2022


Author: Thomas Raoux
Date: 2022-04-15T03:47:52Z
New Revision: 59058c441a9ba421b8f45cf1482544fd72ecb558

URL: https://github.com/llvm/llvm-project/commit/59058c441a9ba421b8f45cf1482544fd72ecb558
DIFF: https://github.com/llvm/llvm-project/commit/59058c441a9ba421b8f45cf1482544fd72ecb558.diff

LOG: [mlir][vector] Add operations used for Vector distribution

Add vector op warp_execute_on_lane_0 that will be used to do incremental
vector distribution in order to target warp level vector programming for
architectures with GPU-like SIMT programming model.
The idea behind the op is discussed further on discourse:
https://discourse.llvm.org/t/vector-vector-distribution-large-vector-to-small-vector/1983/23

Differential Revision: https://reviews.llvm.org/D123703

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/CMakeLists.txt
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index b5e9f25c710e2..39c6353cadffb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3e9ad30cd8bc3..76daf9e387c2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,6 +13,7 @@
 #ifndef VECTOR_OPS
 #define VECTOR_OPS
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -2539,4 +2540,139 @@ def Vector_ScanOp :
   let hasVerifier = 1;
 }
 
+def Vector_YieldOp : Vector_Op<"yield", [
+    NoSideEffect, ReturnLike, Terminator]> {
+  let summary = "Terminates and yields values from vector regions.";
+  let description = [{
+    "vector.yield" yields an SSA value from the Vector dialect op region and
+    terminates the regions. The semantics of how the values are yielded is
+    defined by the parent operation.
+    If "vector.yield" has any operands, the operands must correspond to the
+    parent operation's results.
+    If the parent operation defines no value the vector.yield may be omitted
+    when printing the region.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [
+    OpBuilder<(ins), [{ /* nothing to do */ }]>,
+  ];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
+       SingleBlockImplicitTerminator<"vector::YieldOp">,
+       RecursiveSideEffects]> {
+  let summary = "Executes operations in the associated region on lane #0 of a "
+                "GPU SIMT warp";
+  let description = [{
+    `warp_execute_on_lane_0` is an operation used to bridge the gap between
+    vector programming and GPU SIMT programming model. It allows to trivially
+    convert a region of vector code meant to run on a GPU warp into a valid SIMT
+    region and then allows incremental transformation to distribute vector
+    operations on the SIMT lane.
+
+    Any code present in the region would only be executed on first lane
+    based on the `laneid` operand. The `laneid` operand is an integer ID between
+    [0, `warp_size`). The `warp_size` attribute indicates the number of lanes in
+    a warp.
+
+    Operands are vector values distributed on all lanes that may be used by
+    the single lane execution. The matching region argument is a vector of all
+    the values of those lanes available to the single active lane. The
+    distributed dimension is implicit based on the shape of the operand and
+    argument. In the future this may be described by an affine map.
+
+    Return values are distributed on all lanes using laneId as index. The
+    vector is distributed based on the shape ratio between the vector type of
+    the yield and the result type.
+    If the shapes are the same this means the value is broadcasted to all lanes.
+    In the future the distribution can be made more explicit using affine_maps
+    and will support having multiple Ids.
+
+    Therefore the `warp_execute_on_lane_0` operations allow to implicitly copy
+    between lane0 and the lanes of the warp. When distributing a vector
+    from lane0 to all the lanes, the data are distributed in a block cyclic way.
+
+    During lowering values passed as operands and return value need to be
+    visible to different lanes within the warp. This would usually be done by
+    going through memory.
+
+    The region is *not* isolated from above. For values coming from the parent
+    region not going through operands only the lane 0 value will be accessible so
+    it generally only makes sense for uniform values.
+
+    Example:
+    ```
+    vector.warp_execute_on_lane_0 (%laneid)[32] {
+      ...
+    }
+    ```
+
+    This may be lowered to an scf.if region as below:
+    ```
+      %cnd = arith.cmpi eq, %laneid, %c0 : index
+      scf.if %cnd {
+         ...
+      }
+    ```
+
+    When the region has operands and/or return values:
+    ```
+    %0 = vector.warp_execute_on_lane_0(%laneid)[32]
+    args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
+    ^bb0(%arg0 : vector<128xi32>) :
+      ...
+      vector.yield %1 : vector<32xf32>
+    }
+    ```
+
+    Values at the region boundary would go through memory:
+    ```
+    %tmp0 = memref.alloc() : memref<128xi32, 3>
+    %tmp1 = memref.alloc() : memref<32xf32, 3>
+    %cnd = arith.cmpi eq, %laneid, %c0 : index
+    vector.store %v0, %tmp0[%laneid] : memref<128xi32, 3>, vector<4xi32>
+    warp_sync
+    scf.if %cnd {
+      %arg0 = vector.load %tmp0[%c0] : memref<128xi32, 3>, vector<128xi32>
+      ...
+      vector.store %1, %tmp1[%c0] : memref<32xf32, 3>, vector<32xf32>
+    }
+    warp_sync
+    %0 = vector.load %tmp1[%laneid] : memref<32xf32, 3>, vector<1xf32>
+    ```
+
+  }];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+  let arguments = (ins Index:$laneid, I64Attr:$warp_size,
+                       Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$warpRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$laneid, "int64_t":$warpSize)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize)>,
+    // `blockArgTypes` are different than `args` types as they
+    // represent all the `args` instances visible to lane 0. Therefore we need
+    // to explicitly pass the type.
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize, "ValueRange":$args,
+                   "TypeRange":$blockArgTypes)>
+  ];
+
+  let extraClassDeclaration = [{
+    bool isDefinedOutsideOfRegion(Value value) {
+      return !getRegion().isAncestor(value.getParentRegion());
+    }
+  }];
+}
+
 #endif // VECTOR_OPS

diff  --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 59647539d10e9..17380bd049db5 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVector
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDialectUtils
   MLIRIR

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 940f9262f1472..af174601da570 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4589,6 +4589,184 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
+//===----------------------------------------------------------------------===//
+// WarpExecuteOnLane0Op
+//===----------------------------------------------------------------------===//
+
+void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
+  p << "(" << getLaneid() << ")";
+
+  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
+  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
+  p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
+
+  if (!getArgs().empty())
+    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
+  if (!getResults().empty())
+    p << " -> (" << getResults().getTypes() << ')';
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/!getResults().empty());
+  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
+}
+
+ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  // Create the region.
+  result.regions.reserve(1);
+  Region *warpRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::UnresolvedOperand laneId;
+
+  // Parse the lane id operand.
+  if (parser.parseLParen() || parser.parseRegionArgument(laneId) ||
+      parser.parseRParen())
+    return failure();
+
+  int64_t warpSize;
+  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
+      parser.parseRSquare())
+    return failure();
+  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
+                                                        builder.getContext())),
+                      builder.getI64IntegerAttr(warpSize));
+
+  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
+    return failure();
+
+  llvm::SMLoc inputsOperandsLoc;
+  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
+  SmallVector<Type> inputTypes;
+  if (succeeded(parser.parseOptionalKeyword("args"))) {
+    if (parser.parseLParen())
+      return failure();
+
+    inputsOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(inputsOperands) ||
+        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+      return failure();
+  }
+  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
+                             result.operands))
+    return failure();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+  // Parse the region.
+  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
+                         /*argTypes=*/{}))
+    return failure();
+  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+void WarpExecuteOnLane0Op::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The warp region is always executed
+  regions.push_back(RegionSuccessor(&getWarpRegion()));
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize) {
+  build(builder, result, resultTypes, laneId, warpSize,
+        /*operands=*/llvm::None, /*argTypes=*/llvm::None);
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize, ValueRange args,
+                                 TypeRange blockArgTypes) {
+  result.addOperands(laneId);
+  result.addAttribute(getAttributeNames()[0],
+                      builder.getI64IntegerAttr(warpSize));
+  result.addTypes(resultTypes);
+  result.addOperands(args);
+  assert(args.size() == blockArgTypes.size());
+  OpBuilder::InsertionGuard guard(builder);
+  Region *warpRegion = result.addRegion();
+  Block *block = builder.createBlock(warpRegion);
+  for (auto it : llvm::zip(blockArgTypes, args))
+    block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
+}
+
+/// Helper to check that the distributed vector type is consistent with the
+/// expanded type and the distributed size.
+static LogicalResult verifyDistributedType(Type expanded, Type distributed,
+                                           int64_t warpSize, Operation *op) {
+  // If the types match there is no distribution.
+  if (expanded == distributed)
+    return success();
+  auto expandedVecType = expanded.dyn_cast<VectorType>();
+  auto distributedVecType = distributed.dyn_cast<VectorType>();
+  if (!expandedVecType || !distributedVecType)
+    return op->emitOpError("expected vector type for distributed operands.");
+  if (expandedVecType.getRank() != distributedVecType.getRank() ||
+      expandedVecType.getElementType() != distributedVecType.getElementType())
+    return op->emitOpError(
+        "expected distributed vectors to have same rank and element type.");
+  bool foundDistributedDim = false;
+  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
+    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
+      continue;
+    if (expandedVecType.getDimSize(i) ==
+        distributedVecType.getDimSize(i) * warpSize) {
+      if (foundDistributedDim)
+        return op->emitOpError()
+               << "expected only one dimension to be distributed from "
+               << expandedVecType << " to " << distributedVecType;
+      foundDistributedDim = true;
+      continue;
+    }
+    return op->emitOpError() << "incompatible distribution dimensions from "
+                             << expandedVecType << " to " << distributedVecType;
+  }
+  return success();
+}
+
+LogicalResult WarpExecuteOnLane0Op::verify() {
+  if (getArgs().size() != getWarpRegion().getNumArguments())
+    return emitOpError(
+        "expected same number op arguments and block arguments.");
+  auto yield =
+      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+  if (yield.getNumOperands() != getNumResults())
+    return emitOpError(
+        "expected same number of yield operands and return values.");
+  int64_t warpSize = getWarpSize();
+  for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
+    if (failed(verifyDistributedType(std::get<0>(it).getType(),
+                                     std::get<1>(it).getType(), warpSize,
+                                     getOperation())))
+      return failure();
+  }
+  for (auto it : llvm::zip(yield.getOperands(), getResults())) {
+    if (failed(verifyDistributedType(std::get<0>(it).getType(),
+                                     std::get<1>(it).getType(), warpSize,
+                                     getOperation())))
+      return failure();
+  }
+  return success();
+}
+
+bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
+  return succeeded(
+      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index f60d2b103b882..e3e01b993df36 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1528,3 +1528,78 @@ func @invalid_splat(%v : f32) {
   vector.splat %v : memref<8xf32>
   return
 }
+
+// -----
+
+func @warp_wrong_num_outputs(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected same number of yield operands and return values.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+  }
+  return
+}
+
+// -----
+
+func @warp_wrong_num_inputs(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected same number op arguments and block arguments.}}
+  vector.warp_execute_on_lane_0(%laneid)[64] {
+  ^bb0(%arg0 : vector<128xi32>) :
+  }
+  return
+}
+
+// -----
+
+func @warp_wrong_return_distribution(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}
+
+
+// -----
+
+func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+  vector.warp_execute_on_lane_0(%laneid)[64]
+  args(%v0 : vector<4xi32>) {
+   ^bb0(%arg0 : vector<128xi32>) :
+  }
+  return
+}
+
+// -----
+
+func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+    %0 = arith.constant dense<2>: vector<128x128xi32>
+    vector.yield %0 : vector<128x128xi32>
+  }
+  return
+}
+
+// -----
+
+func @warp_mismatch_rank(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}
+
+// -----
+
+func @warp_mismatch_rank(%laneid: index) {
+  // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
+    %0 = arith.constant dense<2>: vector<128xi32>
+    vector.yield %0 : vector<128xi32>
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43b38efb242eb..3db28eb912def 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -745,3 +745,30 @@ func @vector_splat_0d(%a: f32) -> vector<f32> {
   %0 = vector.splat %a : vector<f32>
   return %0 : vector<f32>
 }
+
+// CHECK-LABEL:   func @warp_execute_on_lane_0(
+func @warp_execute_on_lane_0(%laneid: index) {
+//  CHECK-NEXT:     vector.warp_execute_on_lane_0(%{{.*}})[32] {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+//  CHECK-NEXT:     }
+  }
+//  CHECK-NEXT:     return
+  return
+}
+
+// CHECK-LABEL:   func @warp_operand_result(
+func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
+//  CHECK-NEXT:     %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32]
+  args(%v0 : vector<4xi32>) -> (vector<4xi32>) {
+   ^bb0(%arg0 : vector<128xi32>) :
+    %0 = arith.constant dense<2>: vector<128xi32>
+    %1 = arith.addi %arg0, %0 : vector<128xi32>
+//       CHECK:       vector.yield %{{.*}} : vector<128xi32>
+    vector.yield %1 : vector<128xi32>
+//  CHECK-NEXT:     }
+  }
+  return %2 : vector<4xi32>
+}
+
+

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 656a089082ab3..65096ccd68917 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2963,6 +2963,7 @@ cc_library(
     deps = [
         ":ArithmeticDialect",
         ":ArithmeticUtils",
+        ":ControlFlowInterfaces",
         ":DialectUtils",
         ":IR",
         ":MemRefDialect",
@@ -7275,6 +7276,7 @@ td_library(
     srcs = ["include/mlir/Dialect/Vector/IR/VectorOps.td"],
     includes = ["include"],
     deps = [
+        ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",


        


More information about the Mlir-commits mailing list