Look at the block-job-change command: we have to specify both 'id' to
choose the job to operate on and 'type' for the QAPI union to be parsed.
But for the user this looks redundant: when we specify 'id', QEMU should
be able to get the corresponding job's type.

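This commit makes that possible: specify the name of an enum type
(instead of the name of a base member) as 'discriminator', and define
a function somewhere with the prototype

bool YourUnionType_mapper(YourUnionType *, EnumType *out, Error **errp)

The generated visitor calls this function to obtain the tag value, and
fails the visit if the function returns false.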

Further commits will use this functionality to upgrade the
block-job-change interface and introduce some new interfaces.

Signed-off-by: Vladimir Sementsov-Ogievskiy <vsement...@yandex-team.ru>
---
 scripts/qapi/introspect.py |  5 +++-
 scripts/qapi/schema.py     | 50 +++++++++++++++++++++++---------------
 scripts/qapi/types.py      |  3 ++-
 scripts/qapi/visit.py      | 43 ++++++++++++++++++++++++++------
 4 files changed, 73 insertions(+), 28 deletions(-)

diff --git a/scripts/qapi/introspect.py b/scripts/qapi/introspect.py
index 67c7d89aae..04d8d424f5 100644
--- a/scripts/qapi/introspect.py
+++ b/scripts/qapi/introspect.py
@@ -336,7 +336,10 @@ def visit_object_type_flat(self, name: str, info: Optional[QAPISourceInfo],
             'members': [self._gen_object_member(m) for m in members]
         }
         if variants:
-            obj['tag'] = variants.tag_member.name
+            if variants.tag_member:
+                obj['tag'] = variants.tag_member.name
+            else:
+                obj['discriminator-type'] = variants.tag_type.name
             obj['variants'] = [self._gen_variant(v) for v in variants.variants]
         self._gen_tree(name, 'object', obj, ifcond, features)
 
diff --git a/scripts/qapi/schema.py b/scripts/qapi/schema.py
index 8ba5665bc6..0efe8d815c 100644
--- a/scripts/qapi/schema.py
+++ b/scripts/qapi/schema.py
@@ -585,16 +585,16 @@ def visit(self, visitor):
 
 
 class QAPISchemaVariants:
-    def __init__(self, tag_name, info, tag_member, variants):
+    def __init__(self, discriminator, info, tag_member, variants):
         # Unions pass tag_name but not tag_member.
         # Alternates pass tag_member but not tag_name.
         # After check(), tag_member is always set.
-        assert bool(tag_member) != bool(tag_name)
-        assert (isinstance(tag_name, str) or
+        assert bool(tag_member) != bool(discriminator)
+        assert (isinstance(discriminator, str) or
                 isinstance(tag_member, QAPISchemaObjectTypeMember))
         for v in variants:
             assert isinstance(v, QAPISchemaVariant)
-        self._tag_name = tag_name
+        self.discriminator = discriminator
         self.info = info
         self.tag_member = tag_member
         self.variants = variants
@@ -604,16 +604,24 @@ def set_defined_in(self, name):
             v.set_defined_in(name)
 
     def check(self, schema, seen):
-        if self._tag_name:      # union
-            self.tag_member = seen.get(c_name(self._tag_name))
+        self.tag_type = None
+
+        if self.discriminator:      # assume union with type discriminator
+            self.tag_type = schema.lookup_type(self.discriminator)
+
+        # union with discriminator field
+        if self.discriminator and not self.tag_type:
+            tag_name = self.discriminator
+            self.tag_member = seen.get(c_name(tag_name))
+            self.tag_type = self.tag_member.type
             base = "'base'"
             # Pointing to the base type when not implicit would be
             # nice, but we don't know it here
-            if not self.tag_member or self._tag_name != self.tag_member.name:
+            if not self.tag_member or tag_name != self.tag_member.name:
                 raise QAPISemError(
                     self.info,
                     "discriminator '%s' is not a member of %s"
-                    % (self._tag_name, base))
+                    % (tag_name, base))
             # Here we do:
             base_type = schema.lookup_type(self.tag_member.defined_in)
             assert base_type
@@ -623,29 +631,33 @@ def check(self, schema, seen):
                 raise QAPISemError(
                     self.info,
                     "discriminator member '%s' of %s must be of enum type"
-                    % (self._tag_name, base))
+                    % (tag_name, base))
             if self.tag_member.optional:
                 raise QAPISemError(
                     self.info,
                     "discriminator member '%s' of %s must not be optional"
-                    % (self._tag_name, base))
+                    % (tag_name, base))
             if self.tag_member.ifcond.is_present():
                 raise QAPISemError(
                     self.info,
                     "discriminator member '%s' of %s must not be conditional"
-                    % (self._tag_name, base))
-        else:                   # alternate
+                    % (tag_name, base))
+        elif not self.tag_type:  # alternate
             assert isinstance(self.tag_member.type, QAPISchemaEnumType)
             assert not self.tag_member.optional
             assert not self.tag_member.ifcond.is_present()
-        if self._tag_name:      # union
+
+        if self.tag_type:      # union
             # branches that are not explicitly covered get an empty type
             cases = {v.name for v in self.variants}
-            for m in self.tag_member.type.members:
+            for m in self.tag_type.members:
                 if m.name not in cases:
                     v = QAPISchemaVariant(m.name, self.info,
                                           'q_empty', m.ifcond)
-                    v.set_defined_in(self.tag_member.defined_in)
+                    if self.tag_member:
+                        v.set_defined_in(self.tag_member.defined_in)
+                    else:
+                        v.set_defined_in(self.tag_type.name)
                     self.variants.append(v)
         if not self.variants:
             raise QAPISemError(self.info, "union has no branches")
@@ -654,11 +666,11 @@ def check(self, schema, seen):
             # Union names must match enum values; alternate names are
             # checked separately. Use 'seen' to tell the two apart.
             if seen:
-                if v.name not in self.tag_member.type.member_names():
+                if v.name not in self.tag_type.member_names():
                     raise QAPISemError(
                         self.info,
                         "branch '%s' is not a value of %s"
-                        % (v.name, self.tag_member.type.describe()))
+                        % (v.name, self.tag_type.describe()))
                 if not isinstance(v.type, QAPISchemaObjectType):
                     raise QAPISemError(
                         self.info,
@@ -1116,7 +1128,7 @@ def _make_variant(self, case, typ, ifcond, info):
     def _def_union_type(self, expr: QAPIExpression):
         name = expr['union']
         base = expr['base']
-        tag_name = expr['discriminator']
+        discriminator = expr['discriminator']
         data = expr['data']
         assert isinstance(data, dict)
         info = expr.info
@@ -1136,7 +1148,7 @@ def _def_union_type(self, expr: QAPIExpression):
             QAPISchemaObjectType(name, info, expr.doc, ifcond, features,
                                  base, members,
                                  QAPISchemaVariants(
-                                     tag_name, info, None, variants)))
+                                     discriminator, info, None, variants)))
 
     def _def_alternate_type(self, expr: QAPIExpression):
         name = expr['alternate']
diff --git a/scripts/qapi/types.py b/scripts/qapi/types.py
index c39d054d2c..13021e3f82 100644
--- a/scripts/qapi/types.py
+++ b/scripts/qapi/types.py
@@ -230,7 +230,8 @@ def gen_variants(variants: QAPISchemaVariants) -> str:
     ret = mcgen('''
     union { /* union tag is @%(c_name)s */
 ''',
-                c_name=c_name(variants.tag_member.name))
+                c_name=c_name(variants.tag_member.name if variants.tag_member
+                              else variants.tag_type.name))
 
     for var in variants.variants:
         if var.type.name == 'q_empty':
diff --git a/scripts/qapi/visit.py b/scripts/qapi/visit.py
index c56ea4d724..fbe7ae071d 100644
--- a/scripts/qapi/visit.py
+++ b/scripts/qapi/visit.py
@@ -60,6 +60,13 @@ def gen_visit_members_decl(name: str) -> str:
                  c_name=c_name(name))
 
 
+def gen_visit_members_tag_mapper_decl(name: str, tag_type: str) -> str:
+    return mcgen('''
+
+bool %(c_name)s_mapper(%(c_name)s *obj, %(tag_type)s *tag, Error **errp);
+                 ''', c_name=c_name(name), tag_type=c_name(tag_type))
+
+
 def gen_visit_object_members(name: str,
                              base: Optional[QAPISchemaObjectType],
                              members: List[QAPISchemaObjectTypeMember],
@@ -83,6 +90,12 @@ def gen_visit_object_members(name: str,
             ret += memb.ifcond.gen_endif()
     ret += sep
 
+    if variants:
+        ret += mcgen('''
+    %(c_name)s tag;
+''',
+                     c_name=c_name(variants.tag_type.name))
+
     if base:
         ret += mcgen('''
     if (!visit_type_%(c_type)s_members(v, (%(c_type)s *)obj, errp)) {
@@ -132,17 +145,29 @@ def gen_visit_object_members(name: str,
         ret += memb.ifcond.gen_endif()
 
     if variants:
-        tag_member = variants.tag_member
-        assert isinstance(tag_member.type, QAPISchemaEnumType)
+        tag_type = variants.tag_type
+        assert isinstance(tag_type, QAPISchemaEnumType)
 
-        ret += mcgen('''
-    switch (obj->%(c_name)s) {
+        if variants.tag_member:
+            ret += mcgen('''
+    tag = obj->%(c_name)s;
 ''',
-                     c_name=c_name(tag_member.name))
+                         c_name=c_name(variants.tag_member.name))
+        else:
+            ret += mcgen('''
+    if (!%(c_name)s_mapper(obj, &tag, errp)) {
+        return false;
+    }
+''',
+                         c_name=c_name(name))
+
+        ret += mcgen('''
+    switch (tag) {
+''')
 
         for var in variants.variants:
-            case_str = c_enum_const(tag_member.type.name, var.name,
-                                    tag_member.type.prefix)
+            case_str = c_enum_const(tag_type.name, var.name,
+                                    tag_type.prefix)
             ret += var.ifcond.gen_if()
             if var.type.name == 'q_empty':
                 # valid variant and nothing to do
@@ -400,6 +425,10 @@ def visit_object_type(self,
             return
         with ifcontext(ifcond, self._genh, self._genc):
             self._genh.add(gen_visit_members_decl(name))
+            if variants and variants.tag_type:
+                self._genh.add(
+                    gen_visit_members_tag_mapper_decl(name,
+                                                      variants.tag_type.name))
             self._genc.add(gen_visit_object_members(name, base,
                                                     members, variants))
             # TODO Worth changing the visitor signature, so we could
-- 
2.34.1


Reply via email to