[llvm-branch-commits] [mlir] 8b124c1 - [mlir][sparse] adjust output shape inference to new tensor abstraction

Aart Bik via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 5 15:36:48 PST 2021


Author: Aart Bik
Date: 2021-01-05T15:31:39-08:00
New Revision: 8b124c19f52cb8ed0236b602df56787553e1e1b6

URL: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6
DIFF: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6.diff

LOG: [mlir][sparse] adjust output shape inference to new tensor abstraction

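Nicolas changed the tensor abstraction so that every output has
its own shape definition. This removes the need for the upper-bound
"inference" previously used for the output tensor in the sparse
compiler: dynamic output dimensions are now obtained with a `dim`
operation on the output tensor, just as they are for the inputs.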

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D94119

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index a6b7277e47e3..ed81d5e24805 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -538,15 +538,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
       // Find lower and upper bound in current dimension.
       Value up;
       if (shape[d] == TensorType::kDynamicSize) {
-        // For the output tensor, we may need to infer the upper bound.
-        // For all others, we look at the incoming argument.
-        if (t == numInputs && !op.getNumInitTensors()) {
-          up = codegen.sizes[i];
-          assert(up); // TODO: what else?
-        } else {
-          Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0];
-          up = rewriter.create<DimOp>(loc, arg, d);
-        }
+        Value arg = t < numInputs ? op.getInput(t) : op.getOutput(0);
+        up = rewriter.create<DimOp>(loc, arg, d);
         args.push_back(up);
       } else {
         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);

diff  --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir
index 6612a723f23d..9bb68ca91089 100644
--- a/mlir/test/Dialect/Linalg/sparse_2d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir
@@ -1139,19 +1139,19 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
 // CHECK:           %[[VAL_2:.*]] = constant 999 : index
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_5:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
 // CHECK:           %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
-// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
-// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref<?x?xf64>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xf64>
+// CHECK:           %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_9:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64>
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK:             %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<?xindex>
 // CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK:               %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xf64>
+// CHECK:               %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:               %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xf64>
 // CHECK:               %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64
 // CHECK:               store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<?x?xf64>
 // CHECK:             }


        


More information about the llvm-branch-commits mailing list