[Mlir-commits] [mlir] 96130b5 - [mlir][spirv] Support size-1 vector/tensor constant during conversion
Lei Zhang
llvmlistbot at llvm.org
Tue Dec 14 12:58:17 PST 2021
Author: Lei Zhang
Date: 2021-12-14T15:58:08-05:00
New Revision: 96130b5dc7f600e0c8f5e850ddddfaba952bba34
URL: https://github.com/llvm/llvm-project/commit/96130b5dc7f600e0c8f5e850ddddfaba952bba34
DIFF: https://github.com/llvm/llvm-project/commit/96130b5dc7f600e0c8f5e850ddddfaba952bba34.diff
LOG: [mlir][spirv] Support size-1 vector/tensor constant during conversion
Reviewed By: ThomasRaoux, mravishankar
Differential Revision: https://reviews.llvm.org/D115518
Added:
Modified:
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index da62168576aea..4bd6a4230f9e5 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -273,7 +273,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto srcType = constOp.getType().dyn_cast<ShapedType>();
- if (!srcType)
+ if (!srcType || srcType.getNumElements() == 1)
return failure();
// arith.constant should only have vector or tensor types.
@@ -358,16 +358,25 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Type srcType = constOp.getType();
+ if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
+ if (shapedType.getNumElements() != 1)
+ return failure();
+ srcType = shapedType.getElementType();
+ }
if (!srcType.isIntOrIndexOrFloat())
return failure();
+ Attribute cstAttr = constOp.getValue();
+ if (cstAttr.getType().isa<ShapedType>())
+ cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
+
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();
// Floating-point types.
if (srcType.isa<FloatType>()) {
- auto srcAttr = constOp.getValue().cast<FloatAttr>();
+ auto srcAttr = cstAttr.cast<FloatAttr>();
auto dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
@@ -386,7 +395,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
if (srcType.isInteger(1)) {
// arith.constant can use 0/1 instead of true/false for i1 values. We need
// to handle that here.
- auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter);
+ auto dstAttr = convertBoolAttr(cstAttr, rewriter);
if (!dstAttr)
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -395,7 +404,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
// IndexType or IntegerType. Index values are converted to 32-bit integer
// values when converting to SPIR-V.
- auto srcAttr = constOp.getValue().cast<IntegerAttr>();
+ auto srcAttr = cstAttr.cast<IntegerAttr>();
auto dstAttr =
convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
if (!dstAttr)
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 291c7fdd77b1c..18b7680e8e1b8 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -446,6 +446,17 @@ func @constant_64bit() {
return
}
+// CHECK-LABEL: @constant_size1
+func @constant_size1() {
+ // CHECK: spv.Constant true
+ %0 = arith.constant dense<true> : tensor<1xi1>
+ // CHECK: spv.Constant 4 : i64
+ %1 = arith.constant dense<4> : vector<1xi64>
+ // CHECK: spv.Constant 5.000000e+00 : f64
+ %2 = arith.constant dense<5.0> : tensor<1xf64>
+ return
+}
+
} // end module
// -----
@@ -485,6 +496,15 @@ func @constant_64bit() {
return
}
+// CHECK-LABEL: @constant_size1
+func @constant_size1() {
+ // CHECK: spv.Constant 4 : i32
+ %0 = arith.constant dense<4> : vector<1xi64>
+ // CHECK: spv.Constant 5.000000e+00 : f32
+ %1 = arith.constant dense<5.0> : tensor<1xf64>
+ return
+}
+
// CHECK-LABEL: @corner_cases
func @corner_cases() {
// CHECK: %{{.*}} = spv.Constant -1 : i32
More information about the Mlir-commits
mailing list