[Mlir-commits] [mlir] 73431a4 - [mlir][spirv] Consolidate std.constant to spv.constant conversions

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:32 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: 73431a492badf1b01ac4c60e193d7e7374bb63ff

URL: https://github.com/llvm/llvm-project/commit/73431a492badf1b01ac4c60e193d7e7374bb63ff
DIFF: https://github.com/llvm/llvm-project/commit/73431a492badf1b01ac4c60e193d7e7374bb63ff.diff

LOG: [mlir][spirv] Consolidate std.constant to spv.constant conversions

This commit merges the DRR pattern for std.constant to spv.constant
conversion into the C++ OpConversionPattern. This allows us to
remove the DRR pattern file. Along the way, this commit enhances the
std.constant to spv.constant conversion to take type conversions into
account, which means converting the underlying attributes if necessary.

Differential Revision: https://reviews.llvm.org/D76246

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

Removed: 
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
index 308f1b0074ed..6d940eaf024e 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
@@ -1,7 +1,3 @@
-set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
-mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
-add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
-
 add_mlir_conversion_library(MLIRStandardToSPIRVTransforms
   ConvertStandardToSPIRV.cpp
   ConvertStandardToSPIRVPass.cpp
@@ -10,9 +6,6 @@ add_mlir_conversion_library(MLIRStandardToSPIRVTransforms
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
-
-  DEPENDS
-  MLIRStandardToSPIRVIncGen
   )
 
 target_link_libraries(MLIRStandardToSPIRVTransforms

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 5b6243fa74e9..69ef69d1de65 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -18,6 +18,9 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "std-to-spirv-pattern"
 
 using namespace mlir;
 
@@ -34,6 +37,66 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Converts `srcAttr` into a BoolAttr: a BoolAttr is returned unchanged and an
+/// IntegerAttr maps to true iff nonzero. Returns a null attribute otherwise.
+static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
+  if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
+    return boolAttr;
+  if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
+    return builder.getBoolAttr(intAttr.getValue().getBoolValue());
+  return BoolAttr();
+}
+
+/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
+/// Returns a null attribute if the value cannot be represented in `dstType`.
+static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
+                                      Builder builder) {
+  // If the source number needs no more active bits than the target bitwidth,
+  // the value fits in the target width as an unsigned bit pattern.
+  if (srcAttr.getValue().isIntN(dstType.getWidth()))
+    return builder.getIntegerAttr(dstType, srcAttr.getInt());
+
+  // XXX: Try again by interpreting the source number as a signed value.
+  // Although integers in the standard dialect are signless, they can represent
+  // a signed number; the consuming op decides how to interpret it. This is
+  // dangerous, but it seems there is no good way of handling this if we still
+  // want to change the bitwidth. Emit a debug message at least.
+  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
+    auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
+    LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
+                            << dstAttr << "' for type '" << dstType << "'\n");
+    return dstAttr;
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
+                          << "' illegal: cannot fit into target type '"
+                          << dstType << "'\n");
+  return IntegerAttr();
+}
+
+/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
+/// Returns a null attribute if `dstType` is not f32 or conversion is inexact.
+static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
+                                  Builder builder) {
+  // Only support converting to f32 for now.
+  if (!dstType.isF32())
+    return FloatAttr();
+
+  // Convert to single precision; any inexact or lossy result is rejected.
+  APFloat dstVal = srcAttr.getValue();
+  bool losesInfo = false;
+  APFloat::opStatus status =
+      dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
+  if (status != APFloat::opOK || losesInfo) {
+    LLVM_DEBUG(llvm::dbgs()
+               << srcAttr << " illegal: cannot fit into converted type '"
+               << dstType << "'\n");
+    return FloatAttr();
+  }
+
+  return builder.getF32FloatAttr(dstVal.convertToFloat());
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -97,7 +160,7 @@ class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
 
   LogicalResult
-  matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -107,7 +170,7 @@ class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
 
   LogicalResult
-  matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -211,31 +274,84 @@ class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
-    ConstantOp constCompositeOp, ArrayRef<Value> operands,
+    ConstantOp constOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  auto compositeType =
-      constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
-  if (!compositeType)
+  auto srcType = constOp.getType().dyn_cast<ShapedType>();
+  if (!srcType)
     return failure();
 
-  auto spirvCompositeType = typeConverter.convertType(compositeType);
-  if (!spirvCompositeType)
+  // std.constant should only have vector or tensor types.
+  assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>());
+
+  auto dstType = typeConverter.convertType(srcType);
+  if (!dstType)
     return failure();
 
-  auto linearizedElements =
-      constCompositeOp.value().dyn_cast<DenseElementsAttr>();
-  if (!linearizedElements)
+  auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
+  if (!dstElementsAttr)
     return failure();
+  ShapedType dstAttrType = dstElementsAttr.getType();
 
-  // If composite type has rank greater than one, then perform linearization.
-  if (compositeType.getRank() > 1) {
-    auto linearizedType = RankedTensorType::get(compositeType.getNumElements(),
-                                                compositeType.getElementType());
-    linearizedElements = linearizedElements.reshape(linearizedType);
+  // If the composite type has more than one dimension, perform linearization.
+  if (srcType.getRank() > 1) {
+    if (srcType.isa<RankedTensorType>()) {
+      dstAttrType = RankedTensorType::get(srcType.getNumElements(),
+                                          srcType.getElementType());
+      dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+    } else {
+      // TODO(antiagainst): add support for large vectors.
+      return failure();
+    }
+  }
+
+  Type srcElemType = srcType.getElementType();
+  Type dstElemType;
+  // Tensor types are converted to SPIR-V array types; vector types are
+  // converted to SPIR-V vector/array types.
+  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+    dstElemType = arrayType.getElementType();
+  else
+    dstElemType = dstType.cast<VectorType>().getElementType();
+
+  // If the source and destination element types are different, perform
+  // attribute conversion.
+  if (srcElemType != dstElemType) {
+    SmallVector<Attribute, 8> elements;
+    if (srcElemType.isa<FloatType>()) {
+      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
+        FloatAttr dstAttr = convertFloatAttr(
+            srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
+        if (!dstAttr)
+          return failure();
+        elements.push_back(dstAttr);
+      }
+    } else if (srcElemType.isInteger(1)) {
+      return failure();
+    } else {
+      for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
+        IntegerAttr dstAttr =
+            convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
+                               dstElemType.cast<IntegerType>(), rewriter);
+        if (!dstAttr)
+          return failure();
+        elements.push_back(dstAttr);
+      }
+    }
+
+    // Unfortunately, we cannot use dialect-specific types for element
+    // attributes; element attributes only work with standard types. So we need
+    // to prepare another converted standard type for the destination elements
+    // attribute.
+    if (dstAttrType.isa<RankedTensorType>())
+      dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
+    else
+      dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+
+    dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
   }
 
-  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
-      constCompositeOp, spirvCompositeType, linearizedElements);
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
+                                                 dstElementsAttr);
   return success();
 }
 
@@ -244,32 +360,52 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
-    ConstantOp constIndexOp, ArrayRef<Value> operands,
+    ConstantOp constOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!constIndexOp.getResult().getType().isa<IndexType>()) {
+  Type srcType = constOp.getType();
+  if (!srcType.isIntOrIndexOrFloat())
     return failure();
-  }
-  // The attribute has index type which is not directly supported in
-  // SPIR-V. Get the integer value and create a new IntegerAttr.
-  auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
-  if (!constAttr) {
+
+  Type dstType = typeConverter.convertType(srcType);
+  if (!dstType)
     return failure();
+
+  // Floating-point types.
+  if (srcType.isa<FloatType>()) {
+    auto srcAttr = constOp.value().cast<FloatAttr>();
+    auto dstAttr = srcAttr;
+
+    // Floating-point types not supported in the target environment are all
+    // converted to f32 by the type converter; convert the value to match.
+    if (srcType != dstType) {
+      dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
+      if (!dstAttr)
+        return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+    return success();
   }
 
-  // Use the bitwidth set in the value attribute to decide the result type
-  // of the SPIR-V constant operation since SPIR-V does not support index
-  // types.
-  auto constVal = constAttr.getValue();
-  auto constValType = constAttr.getType().dyn_cast<IndexType>();
-  if (!constValType) {
-    return failure();
+  // Bool type (i1).
+  if (srcType.isInteger(1)) {
+    // std.constant can use 0/1 instead of true/false for i1 values. We need to
+    // handle that here.
+    auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
+    if (!dstAttr)
+      return failure();
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
+    return success();
   }
-  auto spirvConstType =
-      typeConverter.convertType(constIndexOp.getResult().getType());
-  auto spirvConstVal =
-      rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
-  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
-                                                 spirvConstVal);
+
+  // IndexType or IntegerType. Index values are converted to 32-bit integer
+  // values when converting to SPIR-V; values that do not fit make this fail.
+  auto srcAttr = constOp.value().cast<IntegerAttr>();
+  auto dstAttr =
+      convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
+  if (!dstAttr)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
   return success();
 }
 
@@ -431,17 +567,10 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
 // Pattern population
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Import the Standard Ops to SPIR-V Patterns.
-#include "StandardToSPIRV.cpp.inc"
-} // namespace
-
 namespace mlir {
 void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
-  // Add patterns that lower operations into SPIR-V dialect.
-  populateWithGenerated(context, &patterns);
   patterns.insert<
       BinaryOpPattern<AddFOp, spirv::FAddOp>,
       BinaryOpPattern<AddIOp, spirv::IAddOp>,

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
deleted file mode 100644
index 016344e16304..000000000000
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
+++ /dev/null
@@ -1,25 +0,0 @@
-//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
-
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines Patterns to lower standard ops to SPIR-V.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD
-#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
-
-include "mlir/Dialect/StandardOps/IR/Ops.td"
-include "mlir/Dialect/SPIRV/SPIRVOps.td"
-
-// Constant Op
-// TODO(ravishankarm): Handle lowering other constant types.
-def : Pat<(ConstantOp:$result $valueAttr),
-          (SPV_ConstantOp $valueAttr),
-          [(SPV_ScalarOrVector $result)]>;
-
-#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 8a53488d33f4..cb5873a1baf0 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -280,9 +280,9 @@ module attributes {
 func @constant() {
   // CHECK: spv.constant true
   %0 = constant true
-  // CHECK: spv.constant 42 : i64
-  %1 = constant 42
-  // CHECK: spv.constant {{[0-9]*\.[0-9]*e?-?[0-9]*}} : f32
+  // CHECK: spv.constant 42 : i32
+  %1 = constant 42 : i32
+  // CHECK: spv.constant 5.000000e-01 : f32
   %2 = constant 0.5 : f32
   // CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
   %3 = constant dense<[2, 3]> : vector<2xi32>
@@ -303,6 +303,114 @@ func @constant() {
   return
 }
 
+// CHECK-LABEL: @constant_16bit
+func @constant_16bit() {
+  // CHECK: spv.constant 4 : i16
+  %0 = constant 4 : i16
+  // CHECK: spv.constant 5.000000e+00 : f16
+  %1 = constant 5.0 : f16
+  // CHECK: spv.constant dense<[2, 3]> : vector<2xi16>
+  %2 = constant dense<[2, 3]> : vector<2xi16>
+  // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16 [2]>
+  %3 = constant dense<4.0> : tensor<5xf16>
+  return
+}
+
+// CHECK-LABEL: @constant_64bit
+func @constant_64bit() {
+  // CHECK: spv.constant 4 : i64
+  %0 = constant 4 : i64
+  // CHECK: spv.constant 5.000000e+00 : f64
+  %1 = constant 5.0 : f64
+  // CHECK: spv.constant dense<[2, 3]> : vector<2xi64>
+  %2 = constant dense<[2, 3]> : vector<2xi64>
+  // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64 [8]>
+  %3 = constant dense<4.0> : tensor<5xf64>
+  return
+}
+
+} // end module
+
+// -----
+
+// Check that constants are converted to 32-bit when no special capability is available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @constant_16bit
+func @constant_16bit() {
+  // CHECK: spv.constant 4 : i32
+  %0 = constant 4 : i16
+  // CHECK: spv.constant 5.000000e+00 : f32
+  %1 = constant 5.0 : f16
+  // CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
+  %2 = constant dense<[2, 3]> : vector<2xi16>
+  // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]>
+  %3 = constant dense<4.0> : tensor<5xf16>
+  // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]>
+  %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
+  return
+}
+
+// CHECK-LABEL: @constant_64bit
+func @constant_64bit() {
+  // CHECK: spv.constant 4 : i32
+  %0 = constant 4 : i64
+  // CHECK: spv.constant 5.000000e+00 : f32
+  %1 = constant 5.0 : f64
+  // CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
+  %2 = constant dense<[2, 3]> : vector<2xi64>
+  // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]>
+  %3 = constant dense<4.0> : tensor<5xf64>
+  // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]>
+  %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf64>
+  return
+}
+
+// CHECK-LABEL: @corner_cases
+func @corner_cases() {
+  // CHECK: %{{.*}} = spv.constant -1 : i32
+  %0 = constant 4294967295  : i64 // 2^32 - 1
+  // CHECK: %{{.*}} = spv.constant 2147483647 : i32
+  %1 = constant 2147483647  : i64 // 2^31 - 1
+  // CHECK: %{{.*}} = spv.constant -2147483648 : i32
+  %2 = constant 2147483648  : i64 // 2^31
+  // CHECK: %{{.*}} = spv.constant -2147483648 : i32
+  %3 = constant -2147483648 : i64 // -2^31
+
+  // CHECK: %{{.*}} = spv.constant -1 : i32
+  %5 = constant -1 : i64
+  // CHECK: %{{.*}} = spv.constant -2 : i32
+  %6 = constant -2 : i64
+  // CHECK: %{{.*}} = spv.constant -1 : i32
+  %7 = constant -1 : index
+  // CHECK: %{{.*}} = spv.constant -2 : i32
+  %8 = constant -2 : index
+
+
+  // CHECK: spv.constant false
+  %9 = constant 0 : i1
+  // CHECK: spv.constant true
+  %10 = constant 1 : i1
+
+  return
+}
+
+// CHECK-LABEL: @unsupported_cases
+func @unsupported_cases() {
+  // CHECK: %{{.*}} = constant 4294967296 : i64
+  %0 = constant 4294967296 : i64 // 2^32
+  // CHECK: %{{.*}} = constant -2147483649 : i64
+  %1 = constant -2147483649 : i64 // -2^31 - 1
+  // CHECK: %{{.*}} = constant 1.0000000000000002 : f64
+  %2 = constant 0x3FF0000000000001 : f64 // smallest number > 1
+  return
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list