[Mlir-commits] [mlir] ec495b5 - [mlir][sparse] Folding operations that try to insert zero into an all-zero sparse tensor
Peiming Liu
llvmlistbot at llvm.org
Thu Aug 25 10:00:15 PDT 2022
Author: Peiming Liu
Date: 2022-08-25T17:00:04Z
New Revision: ec495b53f8b2ddf02626fd3d12d1ec7d36bd2bca
URL: https://github.com/llvm/llvm-project/commit/ec495b53f8b2ddf02626fd3d12d1ec7d36bd2bca
DIFF: https://github.com/llvm/llvm-project/commit/ec495b53f8b2ddf02626fd3d12d1ec7d36bd2bca.diff
LOG: [mlir][sparse] Folding operations that try to insert zero into an all-zero sparse tensor
The operations that fill zero into a newly allocated sparse tensor are redundant: a freshly
allocated sparse tensor is already all-zero. In addition, such cases previously failed to lower
for the test provided in this patch.
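For illustration only (not part of the commit), a minimal sketch of the kind of IR that now folds.
The #CSR encoding and the @fold_example function are made up for this sketch, and the fill is
assumed to have been generalized into a linalg.generic (e.g. by --linalg-generalize-named-ops)
before the FoldInvariantYield pattern runs:

  #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

  func.func @fold_example(%d0: index, %d1: index) -> tensor<?x?xf64, #CSR> {
    // A freshly allocated sparse tensor is already all-zero.
    %cst = arith.constant 0.0 : f64
    %t = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
    // Once generalized, this fill yields the zero constant; the pattern now
    // replaces its result with %t directly, even for dynamic shapes.
    %0 = linalg.fill ins(%cst : f64)
         outs(%t : tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
    return %0 : tensor<?x?xf64, #CSR>
  }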
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D132500
Added:
mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 0d5c7efa20172..9adfacebda0d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -126,9 +126,15 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
!isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
return failure();
auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
- if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
- return failure();
+ // Yielding zero on newly allocated (all-zero) sparse tensors can be
+ // optimized out directly (regardless of dynamic or static size).
+ if (getSparseTensorEncoding(outputType)) {
+ rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+ return success();
+ }
// Incorporate zero value into allocation copy.
+ if (!outputType.hasStaticShape())
+ return failure();
Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
AllocTensorOp a =
op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
new file mode 100644
index 0000000000000..f2812cd7fb673
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+// CHECK-LABEL: func.func @fill_zero_after_alloc
+// CHECK-SAME: %[[TMP_arg0:.*]]: !llvm.ptr<i8>,
+// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>
+// CHECK: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
+// CHECK: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK: %[[TMP_false:.*]] = arith.constant false
+// CHECK: %[[TMP_true:.*]] = arith.constant true
+// CHECK: %[[TMP_c100:.*]] = arith.constant 100 : index
+// CHECK: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK: %[[TMP_0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK: %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref<?xi8>
+// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK: %[[TMP_2:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[TMP_4:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[TMP_5:.*]] = memref.cast %[[TMP_4]] : memref<2xindex> to memref<?xindex>
+// CHECK: memref.store %[[TMP_c0]], %[[TMP_4]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: memref.store %[[TMP_c1]], %[[TMP_4]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[TMP_6:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[TMP_6]])
+// CHECK: %[[TMP_8:.*]] = call @sparseDimSize(%[[TMP_7]], %[[TMP_c1]])
+// CHECK: %[[TMP_9:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xf64>
+// CHECK: %[[TMP_10:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xi1>
+// CHECK: %[[TMP_11:.*]] = memref.alloc(%[[TMP_8]]) : memref<?xindex>
+// CHECK: linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_9]] : memref<?xf64>)
+// CHECK: linalg.fill ins(%[[TMP_false]] : i1) outs(%[[TMP_10]] : memref<?xi1>)
+// CHECK: %[[TMP_12:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c0]])
+// CHECK: %[[TMP_13:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c0]])
+// CHECK: %[[TMP_14:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c1]])
+// CHECK: %[[TMP_15:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c1]])
+// CHECK: %[[TMP_16:.*]] = call @sparseValuesF64(%[[TMP_arg0]])
+// CHECK: %[[TMP_17:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c0]])
+// CHECK: %[[TMP_18:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c0]])
+// CHECK: %[[TMP_19:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c1]])
+// CHECK: %[[TMP_20:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c1]])
+// CHECK: %[[TMP_21:.*]] = call @sparseValuesF64(%[[TMP_arg1]])
+// CHECK: %[[TMP_22:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[TMP_23:.*]] = memref.cast %[[TMP_22]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[TMP_24:.*]] = memref.load %[[TMP_12]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_12]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_24]] to %[[TMP_25]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_13]][%[[TMP_arg2]]] : memref<?xindex>
+// CHECK: memref.store %[[TMP_26]], %[[TMP_22]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_14]][%[[TMP_arg2]]] : memref<?xindex>
+// CHECK: %[[TMP_28:.*]] = arith.addi %[[TMP_arg2]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_29:.*]] = memref.load %[[TMP_14]][%[[TMP_28]]] : memref<?xindex>
+// CHECK: %[[TMP_30:.*]] = memref.load %[[TMP_17]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_31:.*]] = memref.load %[[TMP_17]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: %[[TMP_32:.*]]:3 = scf.while (%[[TMP_arg3:.*]] = %[[TMP_27]], %[[TMP_arg4:.*]] = %[[TMP_30]], %[[TMP_arg5:.*]] = %[[TMP_c0]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[TMP_33:.*]] = arith.cmpi ult, %[[TMP_arg3]], %[[TMP_29]] : index
+// CHECK: %[[TMP_34:.*]] = arith.cmpi ult, %[[TMP_arg4]], %[[TMP_31]] : index
+// CHECK: %[[TMP_35:.*]] = arith.andi %[[TMP_33]], %[[TMP_34]] : i1
+// CHECK: scf.condition(%[[TMP_35]]) %[[TMP_arg3]], %[[TMP_arg4]], %[[TMP_arg5]] : index, index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[TMP_arg3:.*]]: index, %[[TMP_arg4:.*]]: index, %[[TMP_arg5:.*]]: index):
+// CHECK: %[[TMP_33:.*]] = memref.load %[[TMP_15]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_34:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_35:.*]] = arith.cmpi ult, %[[TMP_34]], %[[TMP_33]] : index
+// CHECK: %[[TMP_36:.*]] = arith.select %[[TMP_35]], %[[TMP_34]], %[[TMP_33]] : index
+// CHECK: %[[TMP_37:.*]] = arith.cmpi eq, %[[TMP_33]], %[[TMP_36]] : index
+// CHECK: %[[TMP_38:.*]] = arith.cmpi eq, %[[TMP_34]], %[[TMP_36]] : index
+// CHECK: %[[TMP_39:.*]] = arith.andi %[[TMP_37]], %[[TMP_38]] : i1
+// CHECK: %[[TMP_40:.*]] = scf.if %[[TMP_39]] -> (index) {
+// CHECK: %[[TMP_45:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref<?xf64>
+// CHECK: %[[TMP_46:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref<?xindex>
+// CHECK: %[[TMP_47:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_48:.*]] = memref.load %[[TMP_19]][%[[TMP_47]]] : memref<?xindex>
+// CHECK: %[[TMP_49:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_46]] to %[[TMP_48]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
+// CHECK: %[[TMP_50:.*]] = memref.load %[[TMP_20]][%[[TMP_arg6]]] : memref<?xindex>
+// CHECK: %[[TMP_51:.*]] = memref.load %[[TMP_9]][%[[TMP_50]]] : memref<?xf64>
+// CHECK: %[[TMP_52:.*]] = memref.load %[[TMP_21]][%[[TMP_arg6]]] : memref<?xf64>
+// CHECK: %[[TMP_53:.*]] = arith.mulf %[[TMP_45]], %[[TMP_52]] : f64
+// CHECK: %[[TMP_54:.*]] = arith.addf %[[TMP_51]], %[[TMP_53]] : f64
+// CHECK: %[[TMP_55:.*]] = memref.load %[[TMP_10]][%[[TMP_50]]] : memref<?xi1>
+// CHECK: %[[TMP_56:.*]] = arith.cmpi eq, %[[TMP_55]], %[[TMP_false]] : i1
+// CHECK: %[[TMP_57:.*]] = scf.if %[[TMP_56]] -> (index) {
+// CHECK: memref.store %[[TMP_true]], %[[TMP_10]][%[[TMP_50]]] : memref<?xi1>
+// CHECK: memref.store %[[TMP_50]], %[[TMP_11]][%[[TMP_arg7]]] : memref<?xindex>
+// CHECK: %[[TMP_58:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
+// CHECK: scf.yield %[[TMP_58]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[TMP_arg7]] : index
+// CHECK: }
+// CHECK: memref.store %[[TMP_54]], %[[TMP_9]][%[[TMP_50]]] : memref<?xf64>
+// CHECK: scf.yield %[[TMP_57]] : index
+// CHECK: }
+// CHECK: scf.yield %[[TMP_49]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[TMP_arg5]] : index
+// CHECK: }
+// CHECK: %[[TMP_41:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_42:.*]] = arith.select %[[TMP_37]], %[[TMP_41]], %[[TMP_arg3]] : index
+// CHECK: %[[TMP_43:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_44:.*]] = arith.select %[[TMP_38]], %[[TMP_43]], %[[TMP_arg4]] : index
+// CHECK: scf.yield %[[TMP_42]], %[[TMP_44]], %[[TMP_40]] : index, index, index
+// CHECK: }
+// CHECK: func.call @expInsertF64(%[[TMP_7]], %[[TMP_23]], %[[TMP_9]], %[[TMP_10]], %[[TMP_11]], %[[TMP_32]]#2)
+// CHECK: }
+// CHECK: memref.dealloc %[[TMP_9]] : memref<?xf64>
+// CHECK: memref.dealloc %[[TMP_10]] : memref<?xi1>
+// CHECK: memref.dealloc %[[TMP_11]] : memref<?xindex>
+// CHECK: call @endInsert(%[[TMP_7]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: return %[[TMP_7]] : !llvm.ptr<i8>
+func.func @fill_zero_after_alloc(%arg0: tensor<100x100xf64, #DCSR>,
+ %arg1: tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR> {
+ %0 = bufferization.alloc_tensor() : tensor<100x100xf64, #DCSR>
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = linalg.fill ins(%cst : f64)
+ outs(%0 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x100xf64, #DCSR>, tensor<100x100xf64, #DCSR>)
+ outs(%1 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
+ return %2 : tensor<100x100xf64, #DCSR>
+}