[Mlir-commits] [mlir] 705068c - [mlir][linalg] Support for using output values in TC definitions.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 24 11:38:23 PST 2021


Author: Hanhan Wang
Date: 2021-02-24T11:37:45-08:00
New Revision: 705068cb8c4d86c798c4134f0a332f4a45c7df04

URL: https://github.com/llvm/llvm-project/commit/705068cb8c4d86c798c4134f0a332f4a45c7df04
DIFF: https://github.com/llvm/llvm-project/commit/705068cb8c4d86c798c4134f0a332f4a45c7df04.diff

LOG: [mlir][linalg] Support for using output values in TC definitions.

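This will allow us to define select(pred, in, out) for TC ops, which is useful
for pooling ops. Reductions no longer implicitly use the output tensor as their
first operand; TC definitions must now reference it explicitly, e.g.
`C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));`.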

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D97312

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index cd72aced29af..399eb634c1b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,13 +1,13 @@
 ods_def<MatmulOp>
 implements_interface<LinalgContractionOpInterface> :
 def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+  C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
 }
 
 ods_def<MatmulColumnMajorOp>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
-  C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
+  C(n, m) = std_addf<k>(C(n, m), std_mulf(A(k, m), B(n, k)));
 }
 
 ods_def<MatmulI8I8I32Op>
@@ -15,169 +15,170 @@ implements_interface<LinalgContractionOpInterface> :
 def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   // TODO: ideally something closer to
   //   C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
-  C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
 }
 
 ods_def<MatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
-  C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
 }
 
 ods_def<MatmulI32I32I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
-  C(m, n) = std_addi<k>(std_muli(A(m, k), B(k, n)));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
 }
 
 ods_def<MatvecOp>
 implements_interface<LinalgContractionOpInterface> :
 def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
-  x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
+  x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
 }
 
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
 }
 
 ods_def<MatvecI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
 }
 
 ods_def<MatvecI32I32I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(std_muli(A(m, n), y(n)));
+  x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
 }
 
 ods_def<VecmatOp>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
-  x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
+  x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
 }
 
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
 }
 
 ods_def<VecmatI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
 }
 
-
 ods_def<VecmatI32I32I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(std_muli(y(m), A(m, n)));
+  x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
 }
 
 ods_def<DotOp>
 implements_interface<LinalgContractionOpInterface> :
 def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
-  C() = std_addf<m>(std_mulf(A(m), B(m)));
+  C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
 }
 
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
-  C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
 }
 
 ods_def<DotI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
-  C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
 }
 
-
 ods_def<DotI32I32I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
-  C() = std_addi<m>(std_muli(A(m), B(m)));
+  C() = std_addi<m>(C(), std_muli(A(m), B(m)));
 }
 
-
 ods_def<BatchMatmulOp>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
-  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
+  C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
 }
 
 ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
-  C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+  C(b, m, n) =
+      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
 }
 
 ods_def<BatchMatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
-  C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+  C(b, m, n) =
+      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
 }
 
 
 ods_def<BatchMatmulI32I32I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) {
-  C(b, m, n) = std_addi<k>(std_muli(A(b, m, k), B(b, k, n)));
+  C(b, m, n) = std_addi<k>(C(b, m, n), std_muli(A(b, m, k), B(b, k, n)));
 }
 
 ods_def<ConvWOp>:
 def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
-  O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
+  O(w) = std_addf<kw>(O(w), std_mulf(I(w + kw), K(kw)));
 }
 
 ods_def<ConvNWCOp>:
 def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
-  O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
+  O(n, w, f) = std_addf<kw>(O(n, w, f), std_mulf(I(n, w + kw, c), K(f, kw, c)));
 }
 
 ods_def<ConvNCWOp>:
 def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
-  O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
+  O(n, f, w) = std_addf<kw>(O(n, f, w), std_mulf(I(n, c, w + kw), K(f, c, kw)));
 }
 
 ods_def<ConvHWOp>:
 def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
-  O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
+  O(h, w) = std_addf<kh, kw>(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
 }
 
 ods_def<ConvNHWCOp>:
 def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
-  O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
-    I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+  O(n, h, w, f) = std_addf<kh, kw>(
+      O(n, h, w, f), std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
 }
 
 ods_def<ConvNCHWOp>:
 def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
-  O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
-    I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+  O(n, f, h, w) = std_addf<kh, kw>(
+      O(n, f, h, w), std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
 }
 
 ods_def<ConvDHWOp>:
 def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
-  O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
-    I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+  O(d, h, w) = std_addf<kd, kh, kw>(
+      O(d, h, w), std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
 }
 
 ods_def<ConvNDHWCOp>:
 def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
-  O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
-    I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+  O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+      O(n, d, h, w, f),
+      std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
 }
 
 ods_def<ConvNCDHWOp>:
 def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
-  O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
-    I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+  O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+      O(n, f, d, h, w),
+      std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
 }
 
 ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
@@ -209,8 +210,10 @@ order of (`N`, `OH`, `OW`, `C`, `KH`, `KW`).
 Note: this op only supports channel multiplier == 1.
 """
 {
-  O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
-    I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)));
+  O(n, oh, ow, c) = std_addf<kh, kw>(
+      O(n, oh, ow, c),
+      std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+               K(kh, kw, c)));
 }
 
 ods_def<ConvInputNWCFilterWCFOp>:
@@ -226,6 +229,7 @@ order of (`N`, `W`, `F`, `KW`, `C`).
 """
 {
   O(n, w, f) = std_addf<kw>(
+      O(n, w, f),
       std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f)));
 }
 
@@ -242,6 +246,7 @@ order of (`N`, `F`, `W`, `KW`, `C`).
 """
 {
   O(n, f, w) = std_addf<kw>(
+      O(n, f, w),
       std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f)));
 }
 
@@ -257,10 +262,10 @@ The indexing maps for these three tensors contain 7 dimensions, following the
 order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`).
 """
 {
-  O(n, h, w, f) =
-      std_addf<kh, kw>(std_mulf(I(n, h * strides[0] + kh * dilations[0],
-                                  w * strides[1] + kw * dilations[1], c),
-                                K(kh, kw, c, f)));
+  O(n, h, w, f) = std_addf<kh, kw>(
+      O(n, h, w, f), std_mulf(I(n, h * strides[0] + kh * dilations[0],
+                                w * strides[1] + kw * dilations[1], c),
+                              K(kh, kw, c, f)));
 }
 
 ods_def<ConvInputNCHWFilterHWCFOp>:
@@ -277,10 +282,10 @@ The indexing maps for these three tensors contain 7 dimensions, following the
 order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`).
 """
 {
-  O(n, f, h, w) =
-      std_addf<kh, kw>(std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
-                                  w * strides[1] + kw * dilations[1]),
-                                K(kh, kw, c, f)));
+  O(n, f, h, w) = std_addf<kh, kw>(
+      O(n, f, h, w), std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
+                                w * strides[1] + kw * dilations[1]),
+                              K(kh, kw, c, f)));
 }
 
 ods_def<ConvInputNDHWCFilterDHWCFOp>:
@@ -297,11 +302,11 @@ The indexing maps for these three tensors contain 9 dimensions, following the
 order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`).
 """
 {
-  O(n, d, h, w, f) =
-      std_addf<kd, kh, kw>(std_mulf(I(n, d * strides[0] + kd * dilations[0],
-                                      h * strides[1] + kh * dilations[1],
-                                      w * strides[2] + kw * dilations[2], c),
-                                    K(kd, kh, kw, c, f)));
+  O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+      O(n, d, h, w, f), std_mulf(I(n, d * strides[0] + kd * dilations[0],
+                                   h * strides[1] + kh * dilations[1],
+                                   w * strides[2] + kw * dilations[2], c),
+                                 K(kd, kh, kw, c, f)));
 }
 
 ods_def<ConvInputNCDHWFilterDHWCFOp>:
@@ -318,8 +323,9 @@ The indexing maps for these three tensors contain 9 dimensions, following the
 order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`).
 """
 {
-  O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
-      I(n, c, d * strides[0] + kd * dilations[0],
-        h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]),
-      K(kd, kh, kw, c, f)));
+  O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+      O(n, f, d, h, w), std_mulf(I(n, c, d * strides[0] + kd * dilations[0],
+                                   h * strides[1] + kh * dilations[1],
+                                   w * strides[2] + kw * dilations[2]),
+                                 K(kd, kh, kw, c, f)));
 }

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index b197ba3da65d..7eea58869c1a 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -31,7 +31,7 @@
 //
 ods_def<Test1Op> :
 def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
-  C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+  C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
 }
 
 // ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
@@ -55,7 +55,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //
 ods_def<Test2Op> :
 def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+  C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
 }
 
 // ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
@@ -79,7 +79,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //
 ods_def<Test3Op> :
 def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
-  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+  C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(k, n)));
 }
 
 // Test attribute definitions
@@ -115,7 +115,7 @@ attr(
   array_attr : f32[],
   optional_attr? : f32
 ) {
-  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+  C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(k, n)));
 }
 
 // Test attribute usage in affine expressions
@@ -157,7 +157,7 @@ It has two inputs.
 It has one output.
 """
 {
-  C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+  C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
 }
 
 // Test attribute builder
@@ -172,5 +172,17 @@ ods_def<Test7Op>:
 def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
      attr(attr_a: f32, attr_b: 4xi32)
 {
-  C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
+  C(m) = std_addf<k>(C(m), std_mulf(A(m, k), B(k)));
+}
+
+// Test output arg order.
+// IMPL-LABEL:  void Test8Op::regionBuilder(Block &block, ValueRange captures) {
+//       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
+//       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
+//       IMPL:  Value [[e:.*]] = std_subf([[d]], [[c]]);
+//       IMPL:  (linalg_yield(ValueRange{ [[e]] }));
+ods_def<Test8Op>:
+def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
+{
+  C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
 }

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 4f57322c8be6..52fd9fbcd904 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1087,6 +1087,9 @@ class TCParser {
 
   /// Parses a tensor use.
   struct ComprehensionParsingState {
+    /// The number of operands (which includes inputs and outputs) in a
+    /// comprehension.
+    size_t numArgs;
     AffineDimList dims;
     SmallVector<std::unique_ptr<Expression>, 4> expressions;
     llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
@@ -1510,11 +1513,6 @@ LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
       reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
   }
 
-  // If this op is a reduction, it's first argument is the `currentDefinition`
-  // tensor use.
-  if (!reductionDims.empty())
-    expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
-  LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
 
   auto parseExpr = [&]() -> LogicalResult {
     std::unique_ptr<Expression> e;
@@ -1619,7 +1617,8 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
       auto tensorIter = registeredTensors.find(use.tensorId);
       assert(tensorIter != registeredTensors.end() && "unregistered tensor");
       auto &tensor = tensorIter->getValue();
-      if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
+      if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 &&
+          tensor.indexingMap.getResults() != use.indexingMap.getResults()) {
         LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
         (void)parser.emitError(
             "Unexpected multi-read of a tensor with different accesses");
@@ -1630,6 +1629,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
       tensor.indexingMap = use.indexingMap;
       state.orderedTensorArgs[use] = tensor.index;
     });
+  state.numArgs = seenDefs.size();
   if (failed)
     return failure();
 
@@ -2004,8 +2004,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
 
   // Finally put everything together.
   os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
-                      attrList, state.orderedTensorArgs.size(), attrBuilder,
-                      attrMethods);
+                      attrList, state.numArgs, attrBuilder, attrMethods);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2203,7 +2202,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
 
-  SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
+  SmallVector<TensorUse, 4> orderedUses(state.numArgs);
   for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
 
@@ -2286,7 +2285,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state) {
-  unsigned count = state.orderedTensorArgs.size();
+  unsigned count = state.numArgs;
   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
   std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
   printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
@@ -2326,7 +2325,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
   std::string valueHandleStr;
   llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
   llvm::interleaveComma(
-      state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
+      llvm::seq<int>(0, state.numArgs), valueHandleStringStream, [&](auto) {
         valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
         idx++;
       });


        


More information about the Mlir-commits mailing list