[Mlir-commits] [mlir] 705068c - [mlir][linalg] Support for using output values in TC definitions.
Hanhan Wang
llvmlistbot at llvm.org
Wed Feb 24 11:38:23 PST 2021
Author: Hanhan Wang
Date: 2021-02-24T11:37:45-08:00
New Revision: 705068cb8c4d86c798c4134f0a332f4a45c7df04
URL: https://github.com/llvm/llvm-project/commit/705068cb8c4d86c798c4134f0a332f4a45c7df04
DIFF: https://github.com/llvm/llvm-project/commit/705068cb8c4d86c798c4134f0a332f4a45c7df04.diff
LOG: [mlir][linalg] Support for using output values in TC definitions.
This will allow us to define select(pred, in, out) for TC ops, which is useful
for pooling ops.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D97312
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index cd72aced29af..399eb634c1b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,13 +1,13 @@
ods_def<MatmulOp>
implements_interface<LinalgContractionOpInterface> :
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
- C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+ C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
}
ods_def<MatmulColumnMajorOp>
implements_interface<LinalgContractionOpInterface> :
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
- C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
+ C(n, m) = std_addf<k>(C(n, m), std_mulf(A(k, m), B(n, k)));
}
ods_def<MatmulI8I8I32Op>
@@ -15,169 +15,170 @@ implements_interface<LinalgContractionOpInterface> :
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
// TODO: ideally something closer to
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
- C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
}
ods_def<MatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
}
ods_def<MatmulI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(std_muli(A(m, k), B(k, n)));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
}
ods_def<MatvecOp>
implements_interface<LinalgContractionOpInterface> :
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
- x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
+ x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
}
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
}
ods_def<MatvecI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
}
ods_def<MatvecI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_muli(A(m, n), y(n)));
+ x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
}
ods_def<VecmatOp>
implements_interface<LinalgContractionOpInterface> :
def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
- x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
+ x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
}
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
}
ods_def<VecmatI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
}
-
ods_def<VecmatI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_muli(y(m), A(m, n)));
+ x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
}
ods_def<DotOp>
implements_interface<LinalgContractionOpInterface> :
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
- C() = std_addf<m>(std_mulf(A(m), B(m)));
+ C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
}
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
- C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
}
ods_def<DotI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
- C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
}
-
ods_def<DotI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
- C() = std_addi<m>(std_muli(A(m), B(m)));
+ C() = std_addi<m>(C(), std_muli(A(m), B(m)));
}
-
ods_def<BatchMatmulOp>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
- C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
+ C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
}
ods_def<BatchMatmulI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ C(b, m, n) =
+ std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
}
ods_def<BatchMatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ C(b, m, n) =
+ std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
}
ods_def<BatchMatmulI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_muli(A(b, m, k), B(b, k, n)));
+ C(b, m, n) = std_addi<k>(C(b, m, n), std_muli(A(b, m, k), B(b, k, n)));
}
ods_def<ConvWOp>:
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
- O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
+ O(w) = std_addf<kw>(O(w), std_mulf(I(w + kw), K(kw)));
}
ods_def<ConvNWCOp>:
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
- O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
+ O(n, w, f) = std_addf<kw>(O(n, w, f), std_mulf(I(n, w + kw, c), K(f, kw, c)));
}
ods_def<ConvNCWOp>:
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
- O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
+ O(n, f, w) = std_addf<kw>(O(n, f, w), std_mulf(I(n, c, w + kw), K(f, c, kw)));
}
ods_def<ConvHWOp>:
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
- O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
+ O(h, w) = std_addf<kh, kw>(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
}
ods_def<ConvNHWCOp>:
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
- O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
- I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+ O(n, h, w, f) = std_addf<kh, kw>(
+ O(n, h, w, f), std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
}
ods_def<ConvNCHWOp>:
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
- O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
- I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+ O(n, f, h, w) = std_addf<kh, kw>(
+ O(n, f, h, w), std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
}
ods_def<ConvDHWOp>:
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
- O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+ O(d, h, w) = std_addf<kd, kh, kw>(
+ O(d, h, w), std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
}
ods_def<ConvNDHWCOp>:
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
- O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
- I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+ O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+ O(n, d, h, w, f),
+ std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
}
ods_def<ConvNCDHWOp>:
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
- O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+ O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+ O(n, f, d, h, w),
+ std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
}
ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
@@ -209,8 +210,10 @@ order of (`N`, `OH`, `OW`, `C`, `KH`, `KW`).
Note: this op only supports channel multiplier == 1.
"""
{
- O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
- I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)));
+ O(n, oh, ow, c) = std_addf<kh, kw>(
+ O(n, oh, ow, c),
+ std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+ K(kh, kw, c)));
}
ods_def<ConvInputNWCFilterWCFOp>:
@@ -226,6 +229,7 @@ order of (`N`, `W`, `F`, `KW`, `C`).
"""
{
O(n, w, f) = std_addf<kw>(
+ O(n, w, f),
std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f)));
}
@@ -242,6 +246,7 @@ order of (`N`, `F`, `W`, `KW`, `C`).
"""
{
O(n, f, w) = std_addf<kw>(
+ O(n, f, w),
std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f)));
}
@@ -257,10 +262,10 @@ The indexing maps for these three tensors contain 7 dimensions, following the
order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`).
"""
{
- O(n, h, w, f) =
- std_addf<kh, kw>(std_mulf(I(n, h * strides[0] + kh * dilations[0],
- w * strides[1] + kw * dilations[1], c),
- K(kh, kw, c, f)));
+ O(n, h, w, f) = std_addf<kh, kw>(
+ O(n, h, w, f), std_mulf(I(n, h * strides[0] + kh * dilations[0],
+ w * strides[1] + kw * dilations[1], c),
+ K(kh, kw, c, f)));
}
ods_def<ConvInputNCHWFilterHWCFOp>:
@@ -277,10 +282,10 @@ The indexing maps for these three tensors contain 7 dimensions, following the
order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`).
"""
{
- O(n, f, h, w) =
- std_addf<kh, kw>(std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
- w * strides[1] + kw * dilations[1]),
- K(kh, kw, c, f)));
+ O(n, f, h, w) = std_addf<kh, kw>(
+ O(n, f, h, w), std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
+ w * strides[1] + kw * dilations[1]),
+ K(kh, kw, c, f)));
}
ods_def<ConvInputNDHWCFilterDHWCFOp>:
@@ -297,11 +302,11 @@ The indexing maps for these three tensors contain 9 dimensions, following the
order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`).
"""
{
- O(n, d, h, w, f) =
- std_addf<kd, kh, kw>(std_mulf(I(n, d * strides[0] + kd * dilations[0],
- h * strides[1] + kh * dilations[1],
- w * strides[2] + kw * dilations[2], c),
- K(kd, kh, kw, c, f)));
+ O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+ O(n, d, h, w, f), std_mulf(I(n, d * strides[0] + kd * dilations[0],
+ h * strides[1] + kh * dilations[1],
+ w * strides[2] + kw * dilations[2], c),
+ K(kd, kh, kw, c, f)));
}
ods_def<ConvInputNCDHWFilterDHWCFOp>:
@@ -318,8 +323,9 @@ The indexing maps for these three tensors contain 9 dimensions, following the
order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`).
"""
{
- O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(n, c, d * strides[0] + kd * dilations[0],
- h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]),
- K(kd, kh, kw, c, f)));
+ O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+ O(n, f, d, h, w), std_mulf(I(n, c, d * strides[0] + kd * dilations[0],
+ h * strides[1] + kh * dilations[1],
+ w * strides[2] + kw * dilations[2]),
+ K(kd, kh, kw, c, f)));
}
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index b197ba3da65d..7eea58869c1a 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -31,7 +31,7 @@
//
ods_def<Test1Op> :
def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
- C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+ C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
}
// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
@@ -55,7 +55,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
//
ods_def<Test2Op> :
def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
- C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+ C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
}
// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
@@ -79,7 +79,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
//
ods_def<Test3Op> :
def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
- C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+ C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(k, n)));
}
// Test attribute definitions
@@ -115,7 +115,7 @@ attr(
array_attr : f32[],
optional_attr? : f32
) {
- C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+ C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(k, n)));
}
// Test attribute usage in affine expressions
@@ -157,7 +157,7 @@ It has two inputs.
It has one output.
"""
{
- C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+ C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
}
// Test attribute builder
@@ -172,5 +172,17 @@ ods_def<Test7Op>:
def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
attr(attr_a: f32, attr_b: 4xi32)
{
- C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+ C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
+}
+
+// Test output arg order.
+// IMPL-LABEL: void Test8Op::regionBuilder(Block &block, ValueRange captures) {
+// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
+// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
+// IMPL: Value [[e:.*]] = std_subf([[d]], [[c]]);
+// IMPL: (linalg_yield(ValueRange{ [[e]] }));
+ods_def<Test8Op>:
+def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+{
+ C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 4f57322c8be6..52fd9fbcd904 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1087,6 +1087,9 @@ class TCParser {
/// Parses a tensor use.
struct ComprehensionParsingState {
+ /// The number of operands (which includes inputs and outputs) in a
+ /// comprehension.
+ size_t numArgs;
AffineDimList dims;
SmallVector<std::unique_ptr<Expression>, 4> expressions;
llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
@@ -1510,11 +1513,6 @@ LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
}
- // If this op is a reduction, it's first argument is the `currentDefinition`
- // tensor use.
- if (!reductionDims.empty())
- expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
- LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
auto parseExpr = [&]() -> LogicalResult {
std::unique_ptr<Expression> e;
@@ -1619,7 +1617,8 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
auto tensorIter = registeredTensors.find(use.tensorId);
assert(tensorIter != registeredTensors.end() && "unregistered tensor");
auto &tensor = tensorIter->getValue();
- if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
+ if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 &&
+ tensor.indexingMap.getResults() != use.indexingMap.getResults()) {
LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
(void)parser.emitError(
"Unexpected multi-read of a tensor with different accesses");
@@ -1630,6 +1629,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
tensor.indexingMap = use.indexingMap;
state.orderedTensorArgs[use] = tensor.index;
});
+ state.numArgs = seenDefs.size();
if (failed)
return failure();
@@ -2004,8 +2004,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
// Finally put everything together.
os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
- attrList, state.orderedTensorArgs.size(), attrBuilder,
- attrMethods);
+ attrList, state.numArgs, attrBuilder, attrMethods);
}
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2203,7 +2202,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
std::string mapsStr;
llvm::raw_string_ostream mapsStringStream(mapsStr);
- SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
+ SmallVector<TensorUse, 4> orderedUses(state.numArgs);
for (const auto &it : state.orderedTensorArgs)
orderedUses[it.second] = it.first;
@@ -2286,7 +2285,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
/// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
ComprehensionParsingState &state) {
- unsigned count = state.orderedTensorArgs.size();
+ unsigned count = state.numArgs;
llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
@@ -2326,7 +2325,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
std::string valueHandleStr;
llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
llvm::interleaveComma(
- state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
+ llvm::seq<int>(0, state.numArgs), valueHandleStringStream, [&](auto) {
valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
idx++;
});
More information about the Mlir-commits
mailing list