[Mlir-commits] [mlir] 0935c05 - [mlir][Vector] Add support for 0-D 'vector.shape_cast' lowering

Diego Caballero llvmlistbot at llvm.org
Thu Jun 1 15:23:27 PDT 2023


Author: Diego Caballero
Date: 2023-06-01T22:22:16Z
New Revision: 0935c0556bedc35d841103b58eff9a6e3464ffe6

URL: https://github.com/llvm/llvm-project/commit/0935c0556bedc35d841103b58eff9a6e3464ffe6
DIFF: https://github.com/llvm/llvm-project/commit/0935c0556bedc35d841103b58eff9a6e3464ffe6.diff

LOG: [mlir][Vector] Add support for 0-D 'vector.shape_cast' lowering

This patch adds support for lowering `vector.shape_cast` ops whose source or result is a 0-D vector (e.g. `vector<f32>` to `vector<1xf32>` and back).

Reviewed By: nicolasvasilache, hanchung, awarzynski

Differential Revision: https://reviews.llvm.org/D151851

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
    mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index bd9716cbca94c..f2b28cad76745 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -151,8 +151,26 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         incIdx(srcIdx, sourceVectorType, srcRank - 1);
         incIdx(resIdx, resultVectorType, resRank - 1);
       }
-      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+
+      Value extract;
+      if (srcRank == 0) {
+        // 0-D vector special case
+        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
+        extract = rewriter.create<vector::ExtractElementOp>(
+            loc, op.getSourceVectorType().getElementType(), op.getSource());
+      } else {
+        extract =
+            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      }
+
+      if (resRank == 0) {
+        // 0-D vector special case
+        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
+        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
+      } else {
+        result =
+            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+      }
     }
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 716537ed76ff7..f233a17244ff7 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,3 @@
-
 // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
@@ -124,9 +123,35 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   return %s : vector<2x1x3xf32>
 }
 
+// CHECK-LABEL:   func.func @shape_cast_0d1d(
+// CHECK-SAME:                               %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
+// CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
+// CHECK:           return %[[VAL_3]] : vector<1xf32>
+// CHECK:         }
+
+func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
+  %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  return %s : vector<1xf32>
+}
+
+// CHECK-LABEL:   func.func @shape_cast_1d0d(
+// CHECK-SAME:                               %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
+// CHECK:           return %[[VAL_3]] : vector<f32>
+// CHECK:         }
+
+func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
+  %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
+  return %s : vector<f32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  %f = transform.structured.match ops{["func.func"]} in %module_op 
+  %f = transform.structured.match ops{["func.func"]} in %module_op
     : (!transform.any_op) -> !transform.any_op
 
   %f2 = transform.vector.lower_shape_cast %f


        


More information about the Mlir-commits mailing list