[llvm-branch-commits] [mlir] 8b124c1 - [mlir][sparse] adjust output shape inference to new tensor abstraction
Aart Bik via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 5 15:36:48 PST 2021
Author: Aart Bik
Date: 2021-01-05T15:31:39-08:00
New Revision: 8b124c19f52cb8ed0236b602df56787553e1e1b6
URL: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6
DIFF: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6.diff
LOG: [mlir][sparse] adjust output shape inference to new tensor abstraction
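Nicolas changed the tensor abstraction so that every output has
its own shape definition. This simplifies the output shape "inference"
previously used by the sparse compiler: the upper bound of a dynamic
output dimension is now always obtained directly from the output tensor.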
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D94119
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_2d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index a6b7277e47e3..ed81d5e24805 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -538,15 +538,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
// Find lower and upper bound in current dimension.
Value up;
if (shape[d] == TensorType::kDynamicSize) {
- // For the output tensor, we may need to infer the upper bound.
- // For all others, we look at the incoming argument.
- if (t == numInputs && !op.getNumInitTensors()) {
- up = codegen.sizes[i];
- assert(up); // TODO: what else?
- } else {
- Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0];
- up = rewriter.create<DimOp>(loc, arg, d);
- }
+ Value arg = t < numInputs ? op.getInput(t) : op.getOutput(0);
+ up = rewriter.create<DimOp>(loc, arg, d);
args.push_back(up);
} else {
up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir
index 6612a723f23d..9bb68ca91089 100644
--- a/mlir/test/Dialect/Linalg/sparse_2d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir
@@ -1139,19 +1139,19 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_2:.*]] = constant 999 : index
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
-// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK: %[[VAL_5:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
-// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
-// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref<?x?xf64>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
+// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK: %[[VAL_9:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
+// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index
-// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] {
-// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xf64>
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64
// CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<?x?xf64>
// CHECK: }
More information about the llvm-branch-commits mailing list