XZise has submitted this change and it was merged.
Change subject: TokenWallet: new methods and refactoring ......................................................................
TokenWallet: new methods and refactoring
TokenWallet now raises a pywikibot Error (previously KeyError) if the user requests a token for an action they are not allowed to perform.
Site.preload_tokens() is renamed to Site.get_tokens().
In addition:
- Raise Error in TokenWallet if the user has no rights for the requested action.
- Exception changed to Error.
- Make the message more explicit: "Action 'patrol' is not allowed for user X on Y wiki."
- Added __contains__ method.
- Added __str__ method.
- Added __repr__ method.
- Added a cache for unavailable tokens, to avoid requesting them again.
- Moved the cache for available tokens into TokenWallet.
- Preload all available tokens the first time a token is requested for a user.
Change-Id: Iac9567084dca017a1ac4ff07e4d0c994b51d79e5
---
M pywikibot/site.py
M tests/site_tests.py
2 files changed, 233 insertions(+), 54 deletions(-)
Approvals:
  XZise: Looks good to me, approved
  Mpaa: Looks good to me, but someone else must approve
diff --git a/pywikibot/site.py b/pywikibot/site.py index a5c1754..53edb4c 100644 --- a/pywikibot/site.py +++ b/pywikibot/site.py @@ -1185,20 +1185,61 @@
def __init__(self, site): self.site = site - self.site._tokens = {} - # TODO: Fetch that from the API with paraminfo - self.special_names = set(['deleteglobalaccount', 'patrol', 'rollback', - 'setglobalaccountstatus', 'userrights', - 'watch']) + self._tokens = {} + self.failed_cache = set() # cache unavailable tokens. + + def load_tokens(self, types, all=False): + """Preload one or multiple tokens.""" + assert(self.site.logged_in()) + + self._tokens.setdefault(self.site.user(), {}).update( + self.site.get_tokens(types, all=all)) + + # Preload all only the first time. + # When all=True types is extended in site.get_tokens(). + # Keys not recognised as tokens, are cached so they are not requested + # any longer. + if all: + for key in types: + if key not in self._tokens[self.site.user()]: + self.failed_cache.add((self.site.user(), key))
def __getitem__(self, key): - storage = self.site._tokens.setdefault(self.site.user(), {}) - if (LV(self.site.version()) >= LV('1.24wmf19') - and key not in self.special_names): - key = 'csrf' - if key not in storage: - self.site.preload_tokens([key]) - return storage[key] + assert(self.site.logged_in()) + + user_tokens = self._tokens.setdefault(self.site.user(), {}) + # always preload all for users without tokens + failed_cache_key = (self.site.user(), key) + + try: + key = self.site.validate_tokens([key])[0] + except IndexError: + raise Error( + u"Requested token '{0}' is invalid on {1} wiki." + .format(key, self.site)) + + if (key not in user_tokens and + failed_cache_key not in self.failed_cache): + self.load_tokens([key], all=not user_tokens) + + if key in user_tokens: + return user_tokens[key] + else: + # token not allowed for self.site.user() on self.site + self.failed_cache.add(failed_cache_key) + # to be changed back to a plain KeyError? + raise Error( + u"Action '{0}' is not allowed for user {1} on {2} wiki." + .format(key, self.site.user(), self.site)) + + def __contains__(self, key): + return key in self._tokens.setdefault(self.site.user(), {}) + + def __str__(self): + return self._tokens.__str__() + + def __repr__(self): + return self._tokens.__repr__()
class APISite(BaseSite): @@ -1228,12 +1269,63 @@ # Pages; see method docs for details) -- #
+ # Constants for token management. + # For all MediaWiki versions prior to 1.20. + TOKENS_0 = set(['edit', + 'delete', + 'protect', + 'move', + 'block', + 'unblock', + 'email', + 'import', + 'watch', + ]) + + # For all MediaWiki versions, with 1.20 <= version < 1.24wmf19 + TOKENS_1 = set(['block', + 'centralauth', + 'delete', + 'deleteglobalaccount', + 'edit', + 'email', + 'import', + 'move', + 'options', + 'patrol', + 'protect', + 'setglobalaccountstatus', + 'unblock', + 'watch', + ]) + + # For all MediaWiki versions >= 1.24wmf19 + TOKENS_2 = set(['csrf', + 'deleteglobalaccount', + 'patrol', + 'rollback', + 'setglobalaccountstatus', + 'userrights', + 'watch', + ]) + def __init__(self, code, fam=None, user=None, sysop=None): """ Constructor. """ BaseSite.__init__(self, code, fam, user, sysop) self._msgcache = {} self._loginstatus = LoginStatus.NOT_ATTEMPTED self._siteinfo = Siteinfo(self) + self.tokens = TokenWallet(self) + + def __getstate__(self): + """ Remove token wallet before pickling. """ + new = super(APISite, self).__getstate__() + del new['tokens'] + return new + + def __setstate__(self, attrs): + """ Restore things removed in __getstate__. """ + super(APISite, self).__setstate__() self.tokens = TokenWallet(self)
@staticmethod @@ -2266,15 +2358,39 @@ api.update_page(page, pagedata) yield page
- def preload_tokens(self, types): + def validate_tokens(self, types): + """Validate if requested tokens are acceptable. + + Valid tokens depend on mw version. + + """ + + _version = LV(self.version()) + if _version < LV('1.20'): + valid_types = [token for token in types if token in self.TOKENS_0] + elif _version < LV('1.24wmf19'): + valid_types = [token for token in types if token in self.TOKENS_1] + else: + valid_types = [] + for token in types: + if ((token in self.TOKENS_0 or token in self.TOKENS_1) and + token not in self.TOKENS_2): + token = 'csrf' + if token in self.TOKENS_2: + valid_types.append(token) + + return valid_types + + def get_tokens(self, types, all=False): """Preload one or multiple tokens.
For all MediaWiki versions prior to 1.20, only one token can be - retrieved at once. For MediaWiki versions since 1.24wmfXXX a new token + retrieved at once. + For MediaWiki versions since 1.24wmfXXX a new token system was introduced which reduced the amount of tokens available. Most of them were merged into the 'csrf' token. If the token type in - the parameter is not known it'll default to the 'csrf' token. The other - token types available are: + the parameter is not known it will default to the 'csrf' token. + The other token types available are: - deleteglobalaccount - patrol - rollback @@ -2285,34 +2401,60 @@ @param types: the types of token (e.g., "edit", "move", "delete"); see API documentation for full list of types @type types: iterable + @param all: load all available tokens + @type all: bool + + return: a dict with retrieved valid tokens. + """ - storage = self._tokens.setdefault(self.user(), {}) - if LV(self.version()) < LV('1.20'): - for tokentype in types: + + def warn_handler(mod, text): + """Filter warnings for not available tokens.""" + return re.match(r'Action '\w+' is not allowed for the current user', + text) + + user_tokens = {} + _version = LV(self.version()) + if _version < LV('1.20'): + if all: + types.extend(self.TOKENS_0) + for tokentype in self.validate_tokens(types): query = api.PropertyGenerator('info', titles='Dummy page', intoken=tokentype, site=self) + query.request._warning_handler = warn_handler + for item in query: pywikibot.debug(unicode(item), _logger) if (tokentype + 'token') in item: - storage[tokentype] = item[tokentype + 'token'] + user_tokens[tokentype] = item[tokentype + 'token'] + else: - if LV(self.version()) < LV('1.24wmf19'): - data = api.Request(site=self, action='tokens', - type='|'.join(types)).submit() + if _version < LV('1.24wmf19'): + if all: + types.extend(self.TOKENS_1) + req = api.Request(site=self, action='tokens', + type='|'.join(self.validate_tokens(types))) else: - new_tokens = [token if token in 
self.tokens.special_names else 'csrf' - for token in types] - data = api.Request(site=self, action='query', meta='tokens', - type='|'.join(new_tokens)).submit() - if 'query' in data: - data = data['query'] + if all: + types.extend(self.TOKENS_2) + + req = api.Request(site=self, action='query', meta='tokens', + type='|'.join(self.validate_tokens(types))) + + req._warning_handler = warn_handler + data = req.submit() + + if 'query' in data: + data = data['query']
if 'tokens' in data and data['tokens']: - storage.update(dict((key[:-5], val) + user_tokens = dict((key[:-5], val) for key, val in data['tokens'].items() - if val != '+\')) + if val != '+\') + + return user_tokens
@deprecated("the 'tokens' property") def token(self, page, tokentype): diff --git a/tests/site_tests.py b/tests/site_tests.py index 778d5e6..50fae12 100644 --- a/tests/site_tests.py +++ b/tests/site_tests.py @@ -222,28 +222,6 @@ if a: self.assertEqual(a[0], mainpage)
- def testTokens(self): - """Test ability to get page tokens.""" - mysite = self.get_site() - for ttype in ("edit", "move"): # token types for non-sysops - try: - token = self.site.tokens[ttype] - except KeyError: - raise unittest.SkipTest( - "Testing '%s' token not possible with user on %s" - % (ttype, self.site)) - self.assertIsInstance(token, basestring) - self.assertEqual(token, mysite.tokens[ttype]) - - def testInvalidToken(self): - mysite = self.get_site() - if LV(mysite.version()) >= LV('1.23wmf19'): - # Currently with the new token API all unknown types are treated - # as csrf tokens, so it won't throw an error here - # a patch is in development: https://gerrit.wikimedia.org/r/#/c/159394 - raise unittest.SkipTest('No invalid token with the new token API possible') - self.assertRaises(KeyError, lambda t: mysite.tokens[t], "invalidtype") - def testPreload(self): """Test that preloading works.""" mysite = self.get_site() @@ -1035,6 +1013,64 @@ # and the other following methods in site.py
+class TestSiteTokens(DefaultSiteTestCase): + + """Test cases for tokens in Site methods.""" + + user = True + + def setUp(self): + """Store version.""" + self.mysite = self.get_site() + self._version = LV(self.mysite.version()) + self.orig_version = self.mysite.version + + def tearDown(self): + """Restore version.""" + self.mysite.version = self.orig_version + + def test_tokens_in_mw_119(self): + """Test ability to get page tokens.""" + self.mysite.version = lambda: '1.19' + for ttype in ("edit", "move"): # token types for non-sysops + token = self.site.tokens[ttype] + self.assertIsInstance(token, basestring) + self.assertEqual(token, self.mysite.tokens[ttype]) + # test __contains__ + self.assertIn("edit", self.mysite.tokens) + + def test_tokens_in_mw_120_124wmf18(self): + """Test ability to get page tokens.""" + if self._version < LV('1.20'): + raise unittest.SkipTest( + u'Site %s version %s is too low for this tests.' + % (self.mysite, self._version)) + self.mysite.version = lambda: '1.21' + for ttype in ("edit", "move"): # token types for non-sysops + token = self.mysite.tokens[ttype] + self.assertIsInstance(token, basestring) + self.assertEqual(token, self.mysite.tokens[ttype]) + # test __contains__ + self.assertIn("edit", self.mysite.tokens) + + def test_tokens_in_mw_124wmf19(self): + """Test ability to get page tokens.""" + if self._version < LV('1.24wmf19'): + raise unittest.SkipTest( + u'Site %s version %s is too low for this tests.' + % (self.mysite, self._version)) + self.mysite.version = lambda: '1.24wmf20' + for ttype in ("edit", "move"): # token types for non-sysops + token = self.mysite.tokens[ttype] + self.assertIsInstance(token, basestring) + self.assertEqual(token, self.mysite.tokens[ttype]) + # test __contains__ + self.assertIn("csrf", self.mysite.tokens) + + def testInvalidToken(self): + self.assertRaises(pywikibot.Error, lambda t: self.mysite.tokens[t], "invalidtype") + + class TestSiteExtensions(WikimediaDefaultSiteTestCase):
"""Test cases for Site extensions.""" @@ -1327,10 +1363,11 @@ } }
- def test_is_uploaddisabled(self): + def test_is_uploaddisabled_wp(self): site = self.get_site('wikipediatest') self.assertFalse(site.is_uploaddisabled())
+ def test_is_uploaddisabled_wd(self): site = self.get_site('wikidatatest') self.assertTrue(site.is_uploaddisabled())
pywikibot-commits@lists.wikimedia.org