[Mlir-commits] [mlir] 86bf62f - [mlir][sparse] improve push_back type checking, printing, parsing
Aart Bik
llvmlistbot at llvm.org
Tue Oct 18 09:55:36 PDT 2022
Author: Aart Bik
Date: 2022-10-18T09:55:25-07:00
New Revision: 86bf62fa6435e8a3911c4060bd6a143116d9c52c
URL: https://github.com/llvm/llvm-project/commit/86bf62fa6435e8a3911c4060bd6a143116d9c52c
DIFF: https://github.com/llvm/llvm-project/commit/86bf62fa6435e8a3911c4060bd6a143116d9c52c.diff
LOG: [mlir][sparse] improve push_back type checking, printing, parsing
Rationale:
Enforces type consistency on parsed and generated IR.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D136132
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 14e5af040384..16a056523641 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -238,7 +238,11 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
let hasVerifier = 1;
}
-def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
+def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
+ [TypesMatchWith<"value type matches element type of inBuffer",
+ "inBuffer", "value",
+ "$_self.cast<ShapedType>().getElementType()">,
+ AllTypesMatch<["inBuffer", "outBuffer"]>]>,
Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
@@ -263,19 +267,18 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
Example:
```mlir
- %r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
- : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
+ %r = sparse_tensor.push_back %bufferSizes, %buffer, %val
+ {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
```
```mlir
%r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
- {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
+ {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
```
}];
let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
" `,` $value attr-dict `:` type($bufferSizes) `,`"
- " type($inBuffer) `,` type($value) `to`"
- " type($outBuffer)";
+ " type($inBuffer) `,` type($value)";
}
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index dd81c523ce32..31c6ad5a0262 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -22,7 +22,7 @@
// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
// CHECK: return %[[M]] : memref<?xf64>
func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
return %0 : memref<?xf64>
}
@@ -40,7 +40,7 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
// CHECK: return %[[B]] : memref<?xf64>
func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
return %0 : memref<?xf64>
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 7ac5179b87bf..21aa1603a45d 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -124,6 +124,14 @@ func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: ind
// -----
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f32) -> memref<?xf64> {
+ // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f32
+ return %0 : memref<?xf64>
+}
+
+// -----
+
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
// expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ca3f8843825e..8c8ab6be7037 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -136,10 +136,10 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %a
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
// CHECK: return %[[D]]
func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
return %0 : memref<?xf64>
}
@@ -149,10 +149,10 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
// CHECK: return %[[D]]
func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
return %0 : memref<?xf64>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
index ff57bfee527d..90d1b3760695 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
@@ -16,8 +16,8 @@ module {
%buffer = memref.alloc(%c1) : memref<?xf32>
memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
- %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
- %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+ %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
+ %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
// CHECK: ( 2 )
%sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
More information about the Mlir-commits
mailing list