[Mlir-commits] [mlir] 5571cc7 - [mlir][linalg] Add helpers for linalg.tiled_loop [nfc].
Alexander Belyaev
llvmlistbot at llvm.org
Tue Apr 6 09:17:50 PDT 2021
Author: Alexander Belyaev
Date: 2021-04-06T18:17:37+02:00
New Revision: 5571cc7deed6dc01f4764adbbf3d668866f22173
URL: https://github.com/llvm/llvm-project/commit/5571cc7deed6dc01f4764adbbf3d668866f22173
DIFF: https://github.com/llvm/llvm-project/commit/5571cc7deed6dc01f4764adbbf3d668866f22173.diff
LOG: [mlir][linalg] Add helpers for linalg.tiled_loop [nfc].
Differential Revision: https://reviews.llvm.org/D99968
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a2c3bb7b0279..b0a93f36ab75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,9 +584,33 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
];
let extraClassDeclaration = [{
+ /// Number of operands controlling the loop: lbs, ubs, steps.
+ unsigned getNumControlOperands() { return 3 * getNumLoops(); }
+
ValueRange getInductionVars() {
return getBody()->getArguments();
}
+
+ /// Result tied to a ranked-tensor operand of `outputs()`, or a null
+ /// OpResult for any other operand (memrefs, inputs, control operands).
+ OpResult getTiedOpResult(OpOperand& opOperand) {
+   // Use the same type predicate as the counting loop below, so an operand
+   // that is neither a memref nor a ranked tensor cannot yield a bogus index.
+   if (!opOperand.get().getType().isa<RankedTensorType>()) return OpResult();
+   // Check whether the operand index is in bounds of `outputs()` arg.
+   unsigned operandIndex = opOperand.getOperandNumber();
+   unsigned outputIndexStart = getNumControlOperands() + inputs().size();
+   unsigned outputIndexEnd = outputIndexStart + outputs().size();
+   if (operandIndex < outputIndexStart || operandIndex >= outputIndexEnd)
+     return OpResult();
+   // Results are assumed to map, in order, to the ranked-tensor outputs, so
+   // the result index is the count of ranked-tensor outputs before this one.
+   unsigned resultIndex = 0;
+   for (unsigned i = outputIndexStart; i < operandIndex; ++i)
+     resultIndex += getOperand(i).getType().isa<RankedTensorType>();
+   return getOperation()->getResult(resultIndex);
+ }
+
unsigned getNumLoops() { return step().size(); }
}];
More information about the Mlir-commits
mailing list