XZise has submitted this change and it was merged.

Change subject: TokenWallet: new methods and refactoring
......................................................................


TokenWallet: new methods and refactoring

TokenWallet raises an Error (instead of KeyError) if the user requests a
token for an action that is not allowed.

Site.preload_tokens() was renamed to Site.get_tokens().

In addition:
- Raise Error in TokenWallet if the user has no rights.
- Exception changed to Error.
- Make the message more explicit:
  Action 'patrol' is not allowed for user X on Y wiki.
- Added __contains__ method.
- Added __str__ method.
- Added __repr__ method.
- Added a cache for unavailable tokens, to avoid requesting them again.
- Moved the cache for available tokens into TokenWallet.
- Preload all available tokens the first time a token is requested for a user.

Change-Id: Iac9567084dca017a1ac4ff07e4d0c994b51d79e5
---
M pywikibot/site.py
M tests/site_tests.py
2 files changed, 233 insertions(+), 54 deletions(-)

Approvals:
  XZise: Looks good to me, approved
  Mpaa: Looks good to me, but someone else must approve



diff --git a/pywikibot/site.py b/pywikibot/site.py
index a5c1754..53edb4c 100644
--- a/pywikibot/site.py
+++ b/pywikibot/site.py
@@ -1185,20 +1185,61 @@
 
     def __init__(self, site):
         self.site = site
-        self.site._tokens = {}
-        # TODO: Fetch that from the API with paraminfo
-        self.special_names = set(['deleteglobalaccount', 'patrol', 'rollback',
-                                  'setglobalaccountstatus', 'userrights',
-                                  'watch'])
+        self._tokens = {}
+        self.failed_cache = set()  # cache unavailable tokens.
+
+    def load_tokens(self, types, all=False):
+        """Preload one or multiple tokens."""
+        assert(self.site.logged_in())
+
+        self._tokens.setdefault(self.site.user(), {}).update(
+            self.site.get_tokens(types, all=all))
+
+        # Preload all only the first time.
+        # When all=True types is extended in place by site.get_tokens().
+        # Requested types that did not come back as tokens are cached so
+        # they are not requested again.
+        if all:
+            for key in types:
+                if key not in self._tokens[self.site.user()]:
+                    self.failed_cache.add((self.site.user(), key))
 
     def __getitem__(self, key):
-        storage = self.site._tokens.setdefault(self.site.user(), {})
-        if (LV(self.site.version()) >= LV('1.24wmf19')
-                and key not in self.special_names):
-            key = 'csrf'
-        if key not in storage:
-            self.site.preload_tokens([key])
-        return storage[key]
+        assert(self.site.logged_in())
+
+        user_tokens = self._tokens.setdefault(self.site.user(), {})
+        try:
+            key = self.site.validate_tokens([key])[0]
+        except IndexError:
+            raise Error(
+                u"Requested token '{0}' is invalid on {1} wiki."
+                .format(key, self.site))
+
+        # Use the validated name, as load_tokens() does for its cache keys.
+        failed_cache_key = (self.site.user(), key)
+        if (key not in user_tokens and
+                failed_cache_key not in self.failed_cache):
+            # always preload all for users without tokens
+            self.load_tokens([key], all=not user_tokens)
+
+        if key in user_tokens:
+            return user_tokens[key]
+        else:
+            # token not allowed for self.site.user() on self.site
+            self.failed_cache.add(failed_cache_key)
+            # to be changed back to a plain KeyError?
+            raise Error(
+                u"Action '{0}' is not allowed for user {1} on {2} wiki."
+                .format(key, self.site.user(), self.site))
+
+    def __contains__(self, key):
+        return key in self._tokens.get(self.site.user(), {})
+
+    def __str__(self):
+        return self._tokens.__str__()
+
+    def __repr__(self):
+        return self._tokens.__repr__()
 
 
 class APISite(BaseSite):
@@ -1228,12 +1269,63 @@
 #    Pages; see method docs for details) --
 #
 
+    # Constants for token management.
+    # For all MediaWiki versions prior to 1.20.
+    TOKENS_0 = set(['edit',
+                    'delete',
+                    'protect',
+                    'move',
+                    'block',
+                    'unblock',
+                    'email',
+                    'import',
+                    'watch',
+                    ])
+
+    # For all MediaWiki versions, with 1.20 <= version < 1.24wmf19
+    TOKENS_1 = set(['block',
+                    'centralauth',
+                    'delete',
+                    'deleteglobalaccount',
+                    'edit',
+                    'email',
+                    'import',
+                    'move',
+                    'options',
+                    'patrol',
+                    'protect',
+                    'setglobalaccountstatus',
+                    'unblock',
+                    'watch',
+                    ])
+
+    # For all MediaWiki versions >= 1.24wmf19
+    TOKENS_2 = set(['csrf',
+                    'deleteglobalaccount',
+                    'patrol',
+                    'rollback',
+                    'setglobalaccountstatus',
+                    'userrights',
+                    'watch',
+                    ])
+
     def __init__(self, code, fam=None, user=None, sysop=None):
         """ Constructor. """
         BaseSite.__init__(self, code, fam, user, sysop)
         self._msgcache = {}
         self._loginstatus = LoginStatus.NOT_ATTEMPTED
         self._siteinfo = Siteinfo(self)
+        self.tokens = TokenWallet(self)
+
+    def __getstate__(self):
+        """ Remove token wallet before pickling. """
+        new = super(APISite, self).__getstate__()
+        del new['tokens']
+        return new
+
+    def __setstate__(self, attrs):
+        """ Restore things removed in __getstate__. """
+        super(APISite, self).__setstate__(attrs)
         self.tokens = TokenWallet(self)
 
     @staticmethod
@@ -2266,15 +2358,40 @@
                 api.update_page(page, pagedata)
                 yield page
 
-    def preload_tokens(self, types):
+    def validate_tokens(self, types):
+        """Validate if requested tokens are acceptable.
+
+        Valid tokens depend on mw version. Since 1.24wmf19 most old token
+        types are mapped to 'csrf'. Order is kept and duplicates are dropped.
+        """
+        _version = LV(self.version())
+        new_system = False
+        if _version < LV('1.20'):
+            known_tokens = self.TOKENS_0
+        elif _version < LV('1.24wmf19'):
+            known_tokens = self.TOKENS_1
+        else:
+            known_tokens, new_system = self.TOKENS_2, True
+        valid_types = []
+        for token in types:
+            if (new_system and token not in known_tokens and
+                    (token in self.TOKENS_0 or token in self.TOKENS_1)):
+                token = 'csrf'
+            if token in known_tokens and token not in valid_types:
+                valid_types.append(token)
+        return valid_types
+
+    def get_tokens(self, types, all=False):
         """Preload one or multiple tokens.
 
         For all MediaWiki versions prior to 1.20, only one token can be
-        retrieved at once. For MediaWiki versions since 1.24wmfXXX a new token
+        retrieved at once.
+        For MediaWiki versions since 1.24wmfXXX a new token
         system was introduced which reduced the amount of tokens available.
         Most of them were merged into the 'csrf' token. If the token type in
-        the parameter is not known it'll default to the 'csrf' token. The other
-        token types available are:
+        the parameter is an old, merged type it is mapped to the 'csrf' token;
+        unknown types are dropped (see validate_tokens()).
+        The other token types available are:
         - deleteglobalaccount
         - patrol
         - rollback
@@ -2285,34 +2402,61 @@
         @param types: the types of token (e.g., "edit", "move", "delete");
             see API documentation for full list of types
         @type  types: iterable
+        @param all: load all available tokens; if True, the types list is
+            extended in place with the token types of this MediaWiki version
+        @type all: bool
+
+        @return: a dict with retrieved valid tokens.
+
         """
-        storage = self._tokens.setdefault(self.user(), {})
-        if LV(self.version()) < LV('1.20'):
-            for tokentype in types:
+
+        def warn_handler(mod, text):
+            """Filter warnings for not available tokens."""
+            return re.match(r'Action \'\w+\' is not allowed for the current user',
+                            text)
+
+        user_tokens = {}
+        _version = LV(self.version())
+        if _version < LV('1.20'):
+            if all:
+                types.extend(self.TOKENS_0)
+            for tokentype in self.validate_tokens(types):
                 query = api.PropertyGenerator('info',
                                               titles='Dummy page',
                                               intoken=tokentype,
                                               site=self)
+                query.request._warning_handler = warn_handler
+
                 for item in query:
                     pywikibot.debug(unicode(item), _logger)
                     if (tokentype + 'token') in item:
-                        storage[tokentype] = item[tokentype + 'token']
+                        user_tokens[tokentype] = item[tokentype + 'token']
+
         else:
-            if LV(self.version()) < LV('1.24wmf19'):
-                data = api.Request(site=self, action='tokens',
-                                   type='|'.join(types)).submit()
+            if _version < LV('1.24wmf19'):
+                if all:
+                    types.extend(self.TOKENS_1)
+                req = api.Request(site=self, action='tokens',
+                                  type='|'.join(self.validate_tokens(types)))
             else:
-                new_tokens = [token if token in self.tokens.special_names else 'csrf'
-                              for token in types]
-                data = api.Request(site=self, action='query', meta='tokens',
-                                   type='|'.join(new_tokens)).submit()
-                if 'query' in data:
-                    data = data['query']
+                if all:
+                    types.extend(self.TOKENS_2)
+
+                req = api.Request(site=self, action='query', meta='tokens',
+                                  type='|'.join(self.validate_tokens(types)))
+
+            req._warning_handler = warn_handler
+            data = req.submit()
+
+            if 'query' in data:
+                data = data['query']
 
             if 'tokens' in data and data['tokens']:
-                storage.update(dict((key[:-5], val)
+                user_tokens = dict((key[:-5], val)
                                     for key, val in data['tokens'].items()
-                                    if val != '+\\'))
+                                    if val != '+\\')
+
+        return user_tokens
 
     @deprecated("the 'tokens' property")
     def token(self, page, tokentype):
diff --git a/tests/site_tests.py b/tests/site_tests.py
index 778d5e6..50fae12 100644
--- a/tests/site_tests.py
+++ b/tests/site_tests.py
@@ -222,28 +222,6 @@
         if a:
             self.assertEqual(a[0], mainpage)
 
-    def testTokens(self):
-        """Test ability to get page tokens."""
-        mysite = self.get_site()
-        for ttype in ("edit", "move"):  # token types for non-sysops
-            try:
-                token = self.site.tokens[ttype]
-            except KeyError:
-                raise unittest.SkipTest(
-                    "Testing '%s' token not possible with user on %s"
-                    % (ttype, self.site))
-            self.assertIsInstance(token, basestring)
-            self.assertEqual(token, mysite.tokens[ttype])
-
-    def testInvalidToken(self):
-        mysite = self.get_site()
-        if LV(mysite.version()) >= LV('1.23wmf19'):
-            # Currently with the new token API all unknown types are treated
-            # as csrf tokens, so it won't throw an error here
-            # a patch is in development: https://gerrit.wikimedia.org/r/#/c/159394
-            raise unittest.SkipTest('No invalid token with the new token API possible')
-        self.assertRaises(KeyError, lambda t: mysite.tokens[t], "invalidtype")
-
     def testPreload(self):
         """Test that preloading works."""
         mysite = self.get_site()
@@ -1035,6 +1013,64 @@
     #       and the other following methods in site.py
 
 
+class TestSiteTokens(DefaultSiteTestCase):
+
+    """Test cases for tokens in Site methods."""
+
+    user = True
+
+    def setUp(self):
+        """Store version."""
+        self.mysite = self.get_site()
+        self._version = LV(self.mysite.version())
+        self.orig_version = self.mysite.version
+
+    def tearDown(self):
+        """Restore version."""
+        self.mysite.version = self.orig_version
+
+    def test_tokens_in_mw_119(self):
+        """Test ability to get page tokens."""
+        self.mysite.version = lambda: '1.19'
+        for ttype in ("edit", "move"):  # token types for non-sysops
+            token = self.mysite.tokens[ttype]
+            self.assertIsInstance(token, basestring)
+            self.assertEqual(token, self.mysite.tokens[ttype])
+        # test __contains__
+        self.assertIn("edit", self.mysite.tokens)
+
+    def test_tokens_in_mw_120_124wmf18(self):
+        """Test ability to get page tokens."""
+        if self._version < LV('1.20'):
+            raise unittest.SkipTest(
+                u'Site %s version %s is too low for this tests.'
+                % (self.mysite, self._version))
+        self.mysite.version = lambda: '1.21'
+        for ttype in ("edit", "move"):  # token types for non-sysops
+            token = self.mysite.tokens[ttype]
+            self.assertIsInstance(token, basestring)
+            self.assertEqual(token, self.mysite.tokens[ttype])
+        # test __contains__
+        self.assertIn("edit", self.mysite.tokens)
+
+    def test_tokens_in_mw_124wmf19(self):
+        """Test ability to get page tokens."""
+        if self._version < LV('1.24wmf19'):
+            raise unittest.SkipTest(
+                u'Site %s version %s is too low for this tests.'
+                % (self.mysite, self._version))
+        self.mysite.version = lambda: '1.24wmf20'
+        for ttype in ("edit", "move"):  # token types for non-sysops
+            token = self.mysite.tokens[ttype]
+            self.assertIsInstance(token, basestring)
+            self.assertEqual(token, self.mysite.tokens[ttype])
+        # test __contains__
+        self.assertIn("csrf", self.mysite.tokens)
+
+    def testInvalidToken(self):
+        self.assertRaises(pywikibot.Error, lambda t: self.mysite.tokens[t], "invalidtype")
+
+
 class TestSiteExtensions(WikimediaDefaultSiteTestCase):
 
     """Test cases for Site extensions."""
@@ -1327,10 +1363,13 @@
         }
     }
 
-    def test_is_uploaddisabled(self):
+    def test_is_uploaddisabled_wp(self):
+        """Test that is_uploaddisabled() is False for 'wikipediatest'."""
         site = self.get_site('wikipediatest')
         self.assertFalse(site.is_uploaddisabled())
 
+    def test_is_uploaddisabled_wd(self):
+        """Test that is_uploaddisabled() is True for 'wikidatatest'."""
         site = self.get_site('wikidatatest')
         self.assertTrue(site.is_uploaddisabled())
 

-- 
To view, visit https://gerrit.wikimedia.org/r/159394
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: merged
Gerrit-Change-Id: Iac9567084dca017a1ac4ff07e4d0c994b51d79e5
Gerrit-PatchSet: 19
Gerrit-Project: pywikibot/core
Gerrit-Branch: master
Gerrit-Owner: Mpaa <[email protected]>
Gerrit-Reviewer: John Vandenberg <[email protected]>
Gerrit-Reviewer: Ladsgroup <[email protected]>
Gerrit-Reviewer: Merlijn van Deen <[email protected]>
Gerrit-Reviewer: Mpaa <[email protected]>
Gerrit-Reviewer: XZise <[email protected]>
Gerrit-Reviewer: jenkins-bot <>

_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to