[Mlir-commits] [mlir] 59058c4 - [mlir][vector] Add operations used for Vector distribution
Thomas Raoux
llvmlistbot at llvm.org
Thu Apr 14 20:53:57 PDT 2022
Author: Thomas Raoux
Date: 2022-04-15T03:47:52Z
New Revision: 59058c441a9ba421b8f45cf1482544fd72ecb558
URL: https://github.com/llvm/llvm-project/commit/59058c441a9ba421b8f45cf1482544fd72ecb558
DIFF: https://github.com/llvm/llvm-project/commit/59058c441a9ba421b8f45cf1482544fd72ecb558.diff
LOG: [mlir][vector] Add operations used for Vector distribution
Add the vector op warp_execute_on_lane_0 that will be used to do incremental
vector distribution in order to target warp-level vector programming for
architectures with a GPU-like SIMT programming model.
The idea behind the op is discussed further on Discourse:
https://discourse.llvm.org/t/vector-vector-distribution-large-vector-to-small-vector/1983/23
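For example (adapted from the tests added in this patch):

  %0 = vector.warp_execute_on_lane_0(%laneid)[32]
  args(%v0 : vector<4xi32>) -> (vector<4xi32>) {
  ^bb0(%arg0 : vector<128xi32>) :
    %1 = arith.addi %arg0, %arg0 : vector<128xi32>
    vector.yield %1 : vector<128xi32>
  }

Each of the 32 lanes holds a vector<4xi32> slice of %v0; inside the
region lane 0 sees the full vector<128xi32>, and the yielded value is
distributed back as one vector<4xi32> per lane.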
Differential Revision: https://reviews.llvm.org/D123703
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/CMakeLists.txt
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index b5e9f25c710e2..39c6353cadffb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3e9ad30cd8bc3..76daf9e387c2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -13,6 +13,7 @@
#ifndef VECTOR_OPS
#define VECTOR_OPS
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
@@ -2539,4 +2540,139 @@ def Vector_ScanOp :
let hasVerifier = 1;
}
+def Vector_YieldOp : Vector_Op<"yield", [
+ NoSideEffect, ReturnLike, Terminator]> {
+ let summary = "Terminates and yields values from vector regions.";
+ let description = [{
+ "vector.yield" yields an SSA value from the Vector dialect op region and
+ terminates the regions. The semantics of how the values are yielded is
+ defined by the parent operation.
+ If "vector.yield" has any operands, the operands must correspond to the
+ parent operation's results.
+ If the parent operation defines no value the vector.yield may be omitted
+ when printing the region.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [
+ OpBuilder<(ins), [{ /* nothing to do */ }]>,
+ ];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
+ SingleBlockImplicitTerminator<"vector::YieldOp">,
+ RecursiveSideEffects]> {
+ let summary = "Executes operations in the associated region on lane #0 of a"
+ "GPU SIMT warp";
+ let description = [{
+ `warp_execute_on_lane_0` is an operation used to bridge the gap between
+ vector programming and the GPU SIMT programming model. It allows one to
+ trivially convert a region of vector code meant to run on a GPU warp into
+ a valid SIMT region, and it then allows incremental transformations to
+ distribute the vector operations across the SIMT lanes.
+
+ Any code present in the region is executed only on the first lane, as
+ identified by the `laneid` operand. The `laneid` operand is an integer ID
+ in the range [0, `warp_size`). The `warp_size` attribute indicates the
+ number of lanes in a warp.
+
+ Operands are vector values distributed across all lanes that may be used
+ by the single-lane execution. The matching region argument is a vector of
+ all the values of those lanes, made available to the single active lane.
+ The distributed dimension is implicit based on the shapes of the operand
+ and argument. In the future this may be described by an affine map.
+
+ Return values are distributed across all lanes using `laneid` as the
+ index. The vector is distributed based on the shape ratio between the
+ vector type of the yield and the result type.
+ If the shapes are the same, the value is broadcast to all lanes.
+ In the future the distribution can be made more explicit using affine
+ maps, and multiple IDs will be supported.
+
+ Therefore, the `warp_execute_on_lane_0` operation allows one to implicitly
+ copy values between lane 0 and the other lanes of the warp. When
+ distributing a vector from lane 0 to all lanes, the data is distributed in
+ a block-cyclic way.
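+
+ For example (a sketch; this exact snippet is not among the tests added
+ here): with `warp_size = 32`, yielding a `vector<64xf32>` into a
+ `vector<2xf32>` result gives each lane a two-element slice of the
+ yielded vector:
+ ```
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+   ...
+   vector.yield %val : vector<64xf32>
+ }
+ ```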
+
+ During lowering, the values passed as operands and the return values need
+ to be visible to different lanes within the warp. This would usually be
+ done by going through memory.
+
+ The region is *not* isolated from above. For values coming from the
+ parent region and not passed through the operands, only the lane 0 value
+ will be accessible, so capturing them generally only makes sense for
+ uniform values.
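+
+ For instance (a hypothetical sketch; `%uniform` is an assumed
+ warp-uniform value, not taken from this patch):
+ ```
+ %uniform = arith.constant 1.0 : f32
+ vector.warp_execute_on_lane_0 (%laneid)[32] {
+   // Capturing %uniform from above is safe: it holds the same value on
+   // every lane.
+   %0 = vector.broadcast %uniform : f32 to vector<32xf32>
+   ...
+ }
+ ```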
+
+ Example:
+ ```
+ vector.warp_execute_on_lane_0 (%laneid)[32] {
+ ...
+ }
+ ```
+
+ This may be lowered to an `scf.if` region as below:
+ ```
+ %cnd = arith.cmpi eq, %laneid, %c0 : index
+ scf.if %cnd {
+ ...
+ }
+ ```
+
+ When the region has operands and/or return values:
+ ```
+ %0 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
+ ^bb0(%arg0 : vector<128xi32>) :
+ ...
+ vector.yield %1 : vector<32xf32>
+ }
+ ```
+
+ Values at the region boundary would go through memory (illustrated here
+ for a `vector<1xf32>` operand and result, with a warp of size 32):
+ ```
+ %tmp0 = memref.alloc() : memref<32xf32, 3>
+ %tmp1 = memref.alloc() : memref<32xf32, 3>
+ %cnd = arith.cmpi eq, %laneid, %c0 : index
+ vector.store %v0, %tmp0[%laneid] : memref<32xf32, 3>, vector<1xf32>
+ warp_sync
+ scf.if %cnd {
+   %arg0 = vector.load %tmp0[%c0] : memref<32xf32, 3>, vector<32xf32>
+   ...
+   vector.store %1, %tmp1[%c0] : memref<32xf32, 3>, vector<32xf32>
+ }
+ warp_sync
+ %0 = vector.load %tmp1[%laneid] : memref<32xf32, 3>, vector<1xf32>
+ ```
+
+ }];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+ let arguments = (ins Index:$laneid, I64Attr:$warp_size,
+ Variadic<AnyType>:$args);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$warpRegion);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$laneid, "int64_t":$warpSize)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+ "int64_t":$warpSize)>,
+ // `blockArgTypes` are different from the `args` types, as they
+ // represent all the `args` instances visible to lane 0. Therefore we
+ // need to explicitly pass the types.
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+ "int64_t":$warpSize, "ValueRange":$args,
+ "TypeRange":$blockArgTypes)>
+ ];
+
+ let extraClassDeclaration = [{
+ bool isDefinedOutsideOfRegion(Value value) {
+ return !getRegion().isAncestor(value.getParentRegion());
+ }
+ }];
+}
+
#endif // VECTOR_OPS
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 59647539d10e9..17380bd049db5 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVector
LINK_LIBS PUBLIC
MLIRArithmetic
+ MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRDialectUtils
MLIRIR
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 940f9262f1472..af174601da570 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4589,6 +4589,184 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+//===----------------------------------------------------------------------===//
+// WarpExecuteOnLane0Op
+//===----------------------------------------------------------------------===//
+
+void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
+ p << "(" << getLaneid() << ")";
+
+ SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
+ auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
+ p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
+
+ if (!getArgs().empty())
+ p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
+ if (!getResults().empty())
+ p << " -> (" << getResults().getTypes() << ')';
+ p << " ";
+ p.printRegion(getRegion(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/!getResults().empty());
+ p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
+}
+
+ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
+ OperationState &result) {
+ // Create the region.
+ result.regions.reserve(1);
+ Region *warpRegion = result.addRegion();
+
+ auto &builder = parser.getBuilder();
+ OpAsmParser::UnresolvedOperand laneId;
+
+ // Parse the laneid operand.
+ if (parser.parseLParen() || parser.parseRegionArgument(laneId) ||
+ parser.parseRParen())
+ return failure();
+
+ int64_t warpSize;
+ if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
+ parser.parseRSquare())
+ return failure();
+ result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
+ builder.getContext())),
+ builder.getI64IntegerAttr(warpSize));
+
+ if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
+ return failure();
+
+ llvm::SMLoc inputsOperandsLoc;
+ SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
+ SmallVector<Type> inputTypes;
+ if (succeeded(parser.parseOptionalKeyword("args"))) {
+ if (parser.parseLParen())
+ return failure();
+
+ inputsOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(inputsOperands) ||
+ parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+ return failure();
+ }
+ if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
+ result.operands))
+ return failure();
+
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ // Parse the region.
+ if (parser.parseRegion(*warpRegion, /*arguments=*/{},
+ /*argTypes=*/{}))
+ return failure();
+ WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+void WarpExecuteOnLane0Op::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ if (index.hasValue()) {
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
+ }
+
+ // The warp region is always executed.
+ regions.push_back(RegionSuccessor(&getWarpRegion()));
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTypes, Value laneId,
+ int64_t warpSize) {
+ build(builder, result, resultTypes, laneId, warpSize,
+ /*operands=*/llvm::None, /*argTypes=*/llvm::None);
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+ TypeRange resultTypes, Value laneId,
+ int64_t warpSize, ValueRange args,
+ TypeRange blockArgTypes) {
+ result.addOperands(laneId);
+ result.addAttribute(getAttributeNames()[0],
+ builder.getI64IntegerAttr(warpSize));
+ result.addTypes(resultTypes);
+ result.addOperands(args);
+ assert(args.size() == blockArgTypes.size());
+ OpBuilder::InsertionGuard guard(builder);
+ Region *warpRegion = result.addRegion();
+ Block *block = builder.createBlock(warpRegion);
+ for (auto it : llvm::zip(blockArgTypes, args))
+ block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
+}
+
+/// Helper to check whether the distributed vector type is consistent with
+/// the expanded type and the distributed size.
+static LogicalResult verifyDistributedType(Type expanded, Type distributed,
+ int64_t warpSize, Operation *op) {
+ // If the types match there is no distribution.
+ if (expanded == distributed)
+ return success();
+ auto expandedVecType = expanded.dyn_cast<VectorType>();
+ auto distributedVecType = distributed.dyn_cast<VectorType>();
+ if (!expandedVecType || !distributedVecType)
+ return op->emitOpError("expected vector type for distributed operands.");
+ if (expandedVecType.getRank() != distributedVecType.getRank() ||
+ expandedVecType.getElementType() != distributedVecType.getElementType())
+ return op->emitOpError(
+ "expected distributed vectors to have same rank and element type.");
+ bool foundDistributedDim = false;
+ for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
+ if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
+ continue;
+ if (expandedVecType.getDimSize(i) ==
+ distributedVecType.getDimSize(i) * warpSize) {
+ if (foundDistributedDim)
+ return op->emitOpError()
+ << "expected only one dimension to be distributed from "
+ << expandedVecType << " to " << distributedVecType;
+ foundDistributedDim = true;
+ continue;
+ }
+ return op->emitOpError() << "incompatible distribution dimensions from "
+ << expandedVecType << " to " << distributedVecType;
+ }
+ return success();
+}
+
+LogicalResult WarpExecuteOnLane0Op::verify() {
+ if (getArgs().size() != getWarpRegion().getNumArguments())
+ return emitOpError(
+ "expected same number op arguments and block arguments.");
+ auto yield =
+ cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+ if (yield.getNumOperands() != getNumResults())
+ return emitOpError(
+ "expected same number of yield operands and return values.");
+ int64_t warpSize = getWarpSize();
+ for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
+ if (failed(verifyDistributedType(std::get<0>(it).getType(),
+ std::get<1>(it).getType(), warpSize,
+ getOperation())))
+ return failure();
+ }
+ for (auto it : llvm::zip(yield.getOperands(), getResults())) {
+ if (failed(verifyDistributedType(std::get<0>(it).getType(),
+ std::get<1>(it).getType(), warpSize,
+ getOperation())))
+ return failure();
+ }
+ return success();
+}
+
+bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
+ return succeeded(
+ verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index f60d2b103b882..e3e01b993df36 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1528,3 +1528,78 @@ func @invalid_splat(%v : f32) {
vector.splat %v : memref<8xf32>
return
}
+
+// -----
+
+func @warp_wrong_num_outputs(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected same number of yield operands and return values.}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+ }
+ return
+}
+
+// -----
+
+func @warp_wrong_num_inputs(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected same number of op arguments and block arguments.}}
+ vector.warp_execute_on_lane_0(%laneid)[64] {
+ ^bb0(%arg0 : vector<128xi32>) :
+ }
+ return
+}
+
+// -----
+
+func @warp_wrong_return_distribution(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<4xi32>) {
+ %0 = arith.constant dense<2>: vector<128xi32>
+ vector.yield %0 : vector<128xi32>
+ }
+ return
+}
+
+
+// -----
+
+func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op incompatible distribution dimensions from 'vector<128xi32>' to 'vector<4xi32>'}}
+ vector.warp_execute_on_lane_0(%laneid)[64]
+ args(%v0 : vector<4xi32>) {
+ ^bb0(%arg0 : vector<128xi32>) :
+ }
+ return
+}
+
+// -----
+
+func @warp_2_distributed_dims(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+ %0 = arith.constant dense<2>: vector<128x128xi32>
+ vector.yield %0 : vector<128x128xi32>
+ }
+ return
+}
+
+// -----
+
+func @warp_mismatch_rank(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
+ %0 = arith.constant dense<2>: vector<128xi32>
+ vector.yield %0 : vector<128xi32>
+ }
+ return
+}
+
+// -----
+
+func @warp_mismatch_type(%laneid: index) {
+ // expected-error @+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
+ %0 = arith.constant dense<2>: vector<128xi32>
+ vector.yield %0 : vector<128xi32>
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43b38efb242eb..3db28eb912def 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -745,3 +745,30 @@ func @vector_splat_0d(%a: f32) -> vector<f32> {
%0 = vector.splat %a : vector<f32>
return %0 : vector<f32>
}
+
+// CHECK-LABEL: func @warp_execute_on_lane_0(
+func @warp_execute_on_lane_0(%laneid: index) {
+// CHECK-NEXT: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+// CHECK-NEXT: }
+ }
+// CHECK-NEXT: return
+ return
+}
+
+// CHECK-LABEL: func @warp_operand_result(
+func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
+// CHECK-NEXT: %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {
+ %2 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%v0 : vector<4xi32>) -> (vector<4xi32>) {
+ ^bb0(%arg0 : vector<128xi32>) :
+ %0 = arith.constant dense<2>: vector<128xi32>
+ %1 = arith.addi %arg0, %0 : vector<128xi32>
+// CHECK: vector.yield %{{.*}} : vector<128xi32>
+ vector.yield %1 : vector<128xi32>
+// CHECK-NEXT: }
+ }
+ return %2 : vector<4xi32>
+}
+
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 656a089082ab3..65096ccd68917 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2963,6 +2963,7 @@ cc_library(
deps = [
":ArithmeticDialect",
":ArithmeticUtils",
+ ":ControlFlowInterfaces",
":DialectUtils",
":IR",
":MemRefDialect",
@@ -7275,6 +7276,7 @@ td_library(
srcs = ["include/mlir/Dialect/Vector/IR/VectorOps.td"],
includes = ["include"],
deps = [
+ ":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",