[Mlir-commits] [mlir] 1b31b50 - [MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp.
    Lei Zhang 
    llvmlistbot at llvm.org
       
    Mon Oct  5 14:05:16 PDT 2020
    
    
  
Author: ergawy
Date: 2020-10-05T17:04:55-04:00
New Revision: 1b31b50d384b5f25221ac268ef781d26f5beacc1
URL: https://github.com/llvm/llvm-project/commit/1b31b50d384b5f25221ac268ef781d26f5beacc1
DIFF: https://github.com/llvm/llvm-project/commit/1b31b50d384b5f25221ac268ef781d26f5beacc1.diff
LOG: [MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp.
Adds support for SPIR-V composite specialization constants to spv._reference_of.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D88732
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
    mlir/test/Dialect/SPIRV/structure-ops.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 0e866f02b011..c64606bc50f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -472,7 +472,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
   let summary = "Reference a specialization constant.";
 
   let description = [{
-    Specialization constant in module scope are defined using symbol names.
+    Specialization constants in module scope are defined using symbol names.
     This op generates an SSA value that can be used to refer to the symbol
     within function scope for use in ops that expect an SSA value.
     This operation has no corresponding SPIR-V instruction; it's merely used
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 363785e2b782..ad25ecb427a6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2568,17 +2568,27 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
 //===----------------------------------------------------------------------===//
 
+/// Verifies that `referenceOfOp` names a spv.specConstant or
+/// spv.specConstantComposite symbol whose type matches the op's result type.
 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
-  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
-      SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
-                                           referenceOfOp.spec_const()));
-  if (!specConstOp) {
-    return referenceOfOp.emitOpError("expected spv.specConstant symbol");
-  }
-  if (referenceOfOp.reference().getType() !=
-      specConstOp.default_value().getType()) {
+  Operation *specConstSym = SymbolTable::lookupNearestSymbolFrom(
+      referenceOfOp.getParentOp(), referenceOfOp.spec_const());
+
+  // Scalar spec constants carry their type on the default value; composite
+  // spec constants carry it on their `type` attribute.
+  Type constType;
+  if (auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym))
+    constType = specConstOp.default_value().getType();
+  else if (auto specConstCompositeOp =
+               dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym))
+    constType = specConstCompositeOp.type();
+  else
+    return referenceOfOp.emitOpError(
+        "expected spv.specConstant or spv.SpecConstantComposite symbol");
+
+  if (referenceOfOp.reference().getType() != constType)
     return referenceOfOp.emitOpError("result type mismatch with the referenced "
                                      "specialization constant's type");
-  }
+
   return success();
 }
 
diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 153540ddb281..33966f8b21e9 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -187,6 +187,11 @@ class Deserializer {
     return specConstMap.lookup(id);
   }
 
+  /// Gets the composite spec constant with result <id>, or null if absent.
+  spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
+    return specConstCompositeMap.lookup(id);
+  }
+
   /// Creates a spirv::SpecConstantOp.
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
                                            Attribute defaultValue);
@@ -461,9 +466,12 @@ class Deserializer {
   /// (and type) here. Later when it's used, we materialize the constant.
   DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
 
-  // Result <id> to variable mapping.
+  // Result <id> to spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
 
+  // Result <id> to composite spec constant mapping.
+  DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
@@ -1565,7 +1573,8 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
            << operands[0];
   }
 
-  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+  auto resultID = operands[1];
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
 
   SmallVector<Attribute, 4> elements;
   elements.reserve(operands.size() - 2);
@@ -1574,9 +1583,10 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
     elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
   }
 
-  opBuilder.create<spirv::SpecConstantCompositeOp>(
+  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
       unknownLoc, TypeAttr::get(resultType), symName,
       opBuilder.getArrayAttr(elements));
+  specConstCompositeMap[resultID] = op;
 
   return success();
 }
@@ -2208,6 +2218,12 @@ Value Deserializer::getValue(uint32_t id) {
         opBuilder.getSymbolRefAttr(constOp.getOperation()));
     return referenceOfOp.reference();
   }
+  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, constCompositeOp.type(),
+        opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
+    return referenceOfOp.reference();
+  }
   if (auto undef = getUndefType(id)) {
     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
   }
diff  --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
index 0df930162c74..2cbfcc6d219d 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
@@ -12,6 +12,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32
   spv.specConstant @sc_float spec_id(5) = 1. : f32
 
+  // CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+  spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+
   // CHECK-LABEL: @use
   spv.func @use() -> (i32) "None" {
     // We materialize a `spv._reference_of` op at every use of a
@@ -24,6 +27,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %1 = spv.IAdd %0, %0 : i32
     spv.ReturnValue %1 : i32
   }
+
+  // CHECK-LABEL: @use_composite
+  spv.func @use_composite() -> (i32) "None" {
+    // We materialize a `spv._reference_of` op at every use of a
+    // specialization constant in the deserializer. So two ops here.
+    // CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+    // CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32>
+    // CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+    // CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32>
+    // CHECK: spv.IAdd %[[ITM0]], %[[ITM1]]
+
+    %0 = spv._reference_of @scc : !spv.array<2 x i32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+    %2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32>
+    %3 = spv.IAdd %1, %2 : i32
+    spv.ReturnValue %3 : i32
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_f32_1 = 1.5 : f32
+  spv.specConstant @sc_f32_2 = 2.5 : f32
+  spv.specConstant @sc_f32_3 = 3.5 : f32
+
+  spv.specConstant @sc_i32_1 = 1   : i32
+
+  // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+  // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+  // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+  spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
 }
 
 // -----
diff  --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index 765eba959a26..7bb98b92c3d2 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -496,6 +496,8 @@ spv.module Logical GLSL450 {
   spv.specConstant @sc2 = 42 : i64
   spv.specConstant @sc3 = 1.5 : f32
 
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i1, i64, f32>
+
   // CHECK-LABEL: @reference
   spv.func @reference() -> i1 "None" {
     // CHECK: spv._reference_of @sc1 : i1
@@ -503,6 +505,14 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %0 : i1
   }
 
+  // CHECK-LABEL: @reference_composite
+  spv.func @reference_composite() -> i1 "None" {
+    // CHECK: spv._reference_of @scc : !spv.struct<i1, i64, f32>
+    %0 = spv._reference_of @scc : !spv.struct<i1, i64, f32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.struct<i1, i64, f32>
+    spv.ReturnValue %1 : i1
+  }
+
   // CHECK-LABEL: @initialize
   spv.func @initialize() -> i64 "None" {
     // CHECK: spv._reference_of @sc2 : i64
@@ -534,9 +544,21 @@ func @reference_of() {
 
 // -----
 
+spv.specConstant @sc = 5 : i32
+spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+
+func @reference_of_composite() {
+  // CHECK: spv._reference_of @scc : !spv.array<1 x i32>
+  %0 = spv._reference_of @scc : !spv.array<1 x i32>
+  %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32>
+  return
+}
+
+// -----
+
 spv.module Logical GLSL450 {
   spv.func @foo() -> () "None" {
-    // expected-error @+1 {{expected spv.specConstant symbol}}
+    // expected-error @+1 {{expected spv.specConstant or spv.SpecConstantComposite symbol}}
     %0 = spv._reference_of @sc : i32
     spv.Return
   }
@@ -555,6 +577,18 @@ spv.module Logical GLSL450 {
 
 // -----
 
+spv.module Logical GLSL450 {
+  spv.specConstant @sc = 42 : i32
+  spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+  spv.func @foo() -> () "None" {
+    // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}}
+    %0 = spv._reference_of @scc : f32
+    spv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.specConstant
 //===----------------------------------------------------------------------===//
        
    
    
More information about the Mlir-commits
mailing list