[Mlir-commits] [mlir] 96130b5 - [mlir][spirv] Support size-1 vector/tensor constant during conversion

Lei Zhang llvmlistbot at llvm.org
Tue Dec 14 12:58:17 PST 2021


Author: Lei Zhang
Date: 2021-12-14T15:58:08-05:00
New Revision: 96130b5dc7f600e0c8f5e850ddddfaba952bba34

URL: https://github.com/llvm/llvm-project/commit/96130b5dc7f600e0c8f5e850ddddfaba952bba34
DIFF: https://github.com/llvm/llvm-project/commit/96130b5dc7f600e0c8f5e850ddddfaba952bba34.diff

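LOG: [mlir][spirv] Support size-1 vector/tensor constant during conversion

Single-element vector/tensor `arith.constant` ops are now lowered to scalar
`spv.Constant` ops: ConstantCompositeOpPattern rejects shaped types with one
element, and ConstantScalarOpPattern unwraps their element type and splat value.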

Reviewed By: ThomasRaoux, mravishankar

Differential Revision: https://reviews.llvm.org/D115518

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index da62168576aea..4bd6a4230f9e5 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -273,7 +273,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     arith::ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto srcType = constOp.getType().dyn_cast<ShapedType>();
-  if (!srcType)
+  if (!srcType || srcType.getNumElements() == 1)
     return failure();
 
   // arith.constant should only have vector or tenor types.
@@ -358,16 +358,25 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
     arith::ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type srcType = constOp.getType();
+  if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
+    if (shapedType.getNumElements() != 1)
+      return failure();
+    srcType = shapedType.getElementType();
+  }
   if (!srcType.isIntOrIndexOrFloat())
     return failure();
 
+  Attribute cstAttr = constOp.getValue();
+  if (cstAttr.getType().isa<ShapedType>())
+    cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
+
   Type dstType = getTypeConverter()->convertType(srcType);
   if (!dstType)
     return failure();
 
   // Floating-point types.
   if (srcType.isa<FloatType>()) {
-    auto srcAttr = constOp.getValue().cast<FloatAttr>();
+    auto srcAttr = cstAttr.cast<FloatAttr>();
     auto dstAttr = srcAttr;
 
     // Floating-point types not supported in the target environment are all
@@ -386,7 +395,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
   if (srcType.isInteger(1)) {
     // arith.constant can use 0/1 instead of true/false for i1 values. We need
     // to handle that here.
-    auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter);
+    auto dstAttr = convertBoolAttr(cstAttr, rewriter);
     if (!dstAttr)
       return failure();
     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -395,7 +404,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 
   // IndexType or IntegerType. Index values are converted to 32-bit integer
   // values when converting to SPIR-V.
-  auto srcAttr = constOp.getValue().cast<IntegerAttr>();
+  auto srcAttr = cstAttr.cast<IntegerAttr>();
   auto dstAttr =
       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
   if (!dstAttr)

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 291c7fdd77b1c..18b7680e8e1b8 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -446,6 +446,17 @@ func @constant_64bit() {
   return
 }
 
+// CHECK-LABEL: @constant_size1
+func @constant_size1() {
+  // CHECK: spv.Constant true
+  %0 = arith.constant dense<true> : tensor<1xi1>
+  // CHECK: spv.Constant 4 : i64
+  %1 = arith.constant dense<4> : vector<1xi64>
+  // CHECK: spv.Constant 5.000000e+00 : f64
+  %2 = arith.constant dense<5.0> : tensor<1xf64>
+  return
+}
+
 } // end module
 
 // -----
@@ -485,6 +496,15 @@ func @constant_64bit() {
   return
 }
 
+// CHECK-LABEL: @constant_size1
+func @constant_size1() {
+  // CHECK: spv.Constant 4 : i32
+  %0 = arith.constant dense<4> : vector<1xi64>
+  // CHECK: spv.Constant 5.000000e+00 : f32
+  %1 = arith.constant dense<5.0> : tensor<1xf64>
+  return
+}
+
 // CHECK-LABEL: @corner_cases
 func @corner_cases() {
   // CHECK: %{{.*}} = spv.Constant -1 : i32


        


More information about the Mlir-commits mailing list