jenkins-bot submitted this change.
[IMPR] Rewrite tools.intersect_generators
- Rewrite tools.intersect_generators using itertools.zip_longest
instead of threads. This makes intersect_generators run
up to 10'000 times faster.
- use a set for seen instead of a dict
- use the iterable index for cache instead of thread
- explicitly increase the cache Counter by 1 instead of updating it with a single item
- use a set for active iterables instead of active_count() method
- Move BasicGeneratorIntersectTestCase and GeneratorIntersectTestCase
from thread_tests to tools_tests.py because intersect_generators is
no longer done by threads.
- return early if the iterables found in the cache are a subset of the active iterables
- Add more tests to TestFactoryGenerator.test_intersect_generator
Bug: T293276
Bug: T85623
Change-Id: Ic123eb821abc9abb44d5cab6973e58da911d6d76
---
M pywikibot/tools/__init__.py
M tests/pagegenerators_tests.py
M tests/thread_tests.py
M tests/tools_tests.py
4 files changed, 133 insertions(+), 121 deletions(-)
diff --git a/pywikibot/tools/__init__.py b/pywikibot/tools/__init__.py
index a988b10..eb4fe40 100644
--- a/pywikibot/tools/__init__.py
+++ b/pywikibot/tools/__init__.py
@@ -801,16 +801,15 @@
def intersect_generators(*iterables, allow_duplicates: bool = False):
- """Intersect generators listed in iterables.
+ """Generator of intersect iterables.
- Yield items only if they are yielded by all generators of iterables.
- Threads (via ThreadedGenerator) are used in order to run generators
- in parallel, so that items can be yielded before generators are
- exhausted.
+ Yield items only if they are yielded by all iterables. zip_longest
+ is used to retrieve items from all iterables in parallel, so that
+ items can be yielded before iterables are exhausted.
- Threads are stopped when they are either exhausted or Ctrl-C is pressed.
- Quitting before all generators are finished is attempted if
- there is no more chance of finding an item in all queues.
+ Generator is stopped when all iterables are exhausted. Quitting
+ before all iterables are finished is attempted if there is no more
+ chance of finding an item in all of them.
Sample:
@@ -834,6 +833,9 @@
``allow_duplicates`` as positional argument,
``iterables`` as list type
+ .. versionchanged:: 7.0
+ Reimplemented without threads which is up to 10'000 times faster
+
:param iterables: page generators
:param allow_duplicates: optional keyword argument to allow duplicates
if present in all generators
@@ -861,73 +863,55 @@
yield from iterables[0]
return
- # If any generator is empty, no pages are going to be returned
+ # If any iterable is empty, no pages are going to be returned
for source in iterables:
if not source:
- debug('At least one generator ({!r}) is empty and execution was '
+ debug('At least one iterable ({!r}) is empty and execution was '
'skipped immediately.'.format(source), 'intersect')
return
- # Item is cached to check that it is found n_gen
- # times before being yielded.
+ # Item is cached to check that it is found n_gen times
+ # before being yielded.
cache = collections.defaultdict(collections.Counter)
n_gen = len(iterables)
- # Class to keep track of alive threads.
- # Start new threads and remove completed threads.
- thrlist = ThreadList()
+ ones = collections.Counter(range(n_gen))
+ active_iterables = set(range(n_gen))
+ seen = set()
- for source in iterables:
- threaded_gen = ThreadedGenerator(name=repr(source), target=source)
- threaded_gen.daemon = True
- thrlist.append(threaded_gen)
+ # Get items from iterables in a round-robin way.
+ sentinel = object()
+ for items in zip_longest(*iterables, fillvalue=sentinel):
+ for index, item in enumerate(items):
- ones = collections.Counter(thrlist)
- seen = {}
+ if item is sentinel:
+ active_iterables.discard(index)
+ continue
- while True:
- # Get items from queues in a round-robin way.
- for t in thrlist:
- try:
- # TODO: evaluate if True and timeout is necessary.
- item = t.queue.get(True, 0.1)
+ if not allow_duplicates and hash(item) in seen:
+ continue
- if not allow_duplicates and hash(item) in seen:
- continue
+ # Each cache entry is a Counter of iterables' index
+ cache[item][index] += 1
- # Cache entry is a Counter of ThreadedGenerator objects.
- cache[item].update([t])
- if len(cache[item]) == n_gen:
- if allow_duplicates:
- yield item
- # Remove item from cache if possible.
- if all(el == 1 for el in cache[item].values()):
- cache.pop(item)
- else:
- cache[item] -= ones
- else:
- yield item
- cache.pop(item)
- seen[hash(item)] = True
+ if len(cache[item]) == n_gen:
+ yield item
- active = thrlist.active_count()
- max_cache = n_gen
- if cache.values():
- max_cache = max(len(v) for v in cache.values())
- # No. of active threads is not enough to reach n_gen.
- # We can quit even if some thread is still active.
- # There could be an item in all generators which has not yet
- # appeared from any generator. Only when we have lost one
- # generator, then we can bail out early based on seen items.
- if active < n_gen and n_gen - max_cache > active:
- thrlist.stop_all()
- return
- except queue.Empty:
- pass
- except KeyboardInterrupt:
- thrlist.stop_all()
- # All threads are done.
- if thrlist.active_count() == 0:
+ # Remove item from cache if possible or decrease Counter entry
+ if not allow_duplicates:
+ del cache[item]
+ seen.add(hash(item))
+ elif cache[item] == ones:
+ del cache[item]
+ else:
+ cache[item] -= ones
+
+ # We can quit if an iterable is exhausted and cached iterables are
+ # a subset of active iterables.
+ if len(active_iterables) < n_gen:
+ cached_iterables = set(
+ chain.from_iterable(v.keys() for v in cache.values()))
+ if cached_iterables <= active_iterables:
return
diff --git a/tests/pagegenerators_tests.py b/tests/pagegenerators_tests.py
index beb9c92..6129d5e 100644
--- a/tests/pagegenerators_tests.py
+++ b/tests/pagegenerators_tests.py
@@ -30,7 +30,7 @@
TestCase,
WikidataTestCase,
)
-from tests.thread_tests import GeneratorIntersectTestCase
+from tests.tools_tests import GeneratorIntersectTestCase
LINKSEARCH_MSG = (r'.*pywikibot\.pagegenerators\.LinksearchPageGenerator .*'
@@ -894,12 +894,39 @@
self.assertEqual(tuple(gen), ('A', 'B', 'C'))
def test_intersect_generator(self):
- """Test getCombinedGenerator with generator parameter."""
+ """Test getCombinedGenerator with -intersect option."""
gf = pagegenerators.GeneratorFactory()
gf.handle_arg('-intersect')
- gf.gens = ['Python 3.7-dev']
- gen = gf.getCombinedGenerator(gen='Pywikibot 3.0.dev')
- self.assertEqual(''.join(gen), 'Pyot 3.dev')
+
+ # check whether the generator works for both directions
+ patterns = ['Python 3.7-dev', 'Pywikibot 7.0.dev']
+ for index in range(2):
+ with self.subTest(index=index):
+ gf.gens = [patterns[index]]
+ gen = gf.getCombinedGenerator(gen=patterns[index - 1])
+ self.assertEqual(''.join(gen), 'Pyot 7.dev')
+
+ # check whether the generator works for a very long text
+ patterns.append('PWB 7+ unittest developed with a very long text.')
+ with self.subTest(patterns=patterns):
+ gf.gens = patterns
+ gen = gf.getCombinedGenerator()
+ self.assertEqual(''.join(gen), 'P 7tedvoy.')
+
+ # check whether an early stop fits
+ with self.subTest(comment='Early stop'):
+ gf.gens = 'ABC', 'A Big City'
+ gen = gf.getCombinedGenerator()
+ self.assertEqual(''.join(gen), 'ABC')
+
+ with self.subTest(comment='Commutative'):
+ gf.gens = 'ABB', 'BB'
+ gen1 = gf.getCombinedGenerator()
+ gf2 = pagegenerators.GeneratorFactory()
+ gf2.handle_arg('-intersect')
+ gf2.gens = 'BB', 'ABB'
+ gen2 = gf2.getCombinedGenerator()
+ self.assertEqual(list(gen1), list(gen2))
def test_ns(self):
"""Test namespace option."""
diff --git a/tests/thread_tests.py b/tests/thread_tests.py
index 44ec655..329017f 100644
--- a/tests/thread_tests.py
+++ b/tests/thread_tests.py
@@ -5,10 +5,9 @@
# Distributed under the terms of the MIT license.
#
import unittest
-from collections import Counter
from contextlib import suppress
-from pywikibot.tools import ThreadedGenerator, intersect_generators
+from pywikibot.tools import ThreadedGenerator
from tests.aspects import TestCase
@@ -39,58 +38,6 @@
self.assertEqual(list(thd_gen), list(iterable))
-class GeneratorIntersectTestCase(TestCase):
-
- """Base class for intersect_generators test cases."""
-
- def assertEqualItertools(self, gens):
- """Assert intersect_generators result is same as set intersection."""
- # If they are a generator, we need to convert to a list
- # first otherwise the generator is empty the second time.
- datasets = [list(gen) for gen in gens]
- set_result = set(datasets[0]).intersection(*datasets[1:])
- result = list(intersect_generators(*datasets))
-
- self.assertCountEqual(set(result), result)
- self.assertCountEqual(result, set_result)
-
- def assertEqualItertoolsWithDuplicates(self, gens):
- """Assert intersect_generators result equals Counter intersection."""
- # If they are a generator, we need to convert to a list
- # first otherwise the generator is empty the second time.
- datasets = [list(gen) for gen in gens]
- counter_result = Counter(datasets[0])
- for dataset in datasets[1:]:
- counter_result = counter_result & Counter(dataset)
- counter_result = list(counter_result.elements())
- result = list(intersect_generators(*datasets, allow_duplicates=True))
- self.assertCountEqual(counter_result, result)
-
-
-class BasicGeneratorIntersectTestCase(GeneratorIntersectTestCase):
-
- """Disconnected intersect_generators test cases."""
-
- net = False
-
- def test_intersect_basic(self):
- """Test basic intersect without duplicates."""
- self.assertEqualItertools(['abc', 'db', 'ba'])
-
- def test_intersect_with_dups(self):
- """Test basic intersect with duplicates."""
- self.assertEqualItertools(['aabc', 'dddb', 'baa'])
-
- def test_intersect_with_accepted_dups(self):
- """Test intersect with duplicates accepted."""
- self.assertEqualItertoolsWithDuplicates(['abc', 'db', 'ba'])
- self.assertEqualItertoolsWithDuplicates(['aabc', 'dddb', 'baa'])
- self.assertEqualItertoolsWithDuplicates(['abb', 'bb'])
- self.assertEqualItertoolsWithDuplicates(['bb', 'abb'])
- self.assertEqualItertoolsWithDuplicates(['abbcd', 'abcba'])
- self.assertEqualItertoolsWithDuplicates(['abcba', 'abbcd'])
-
-
if __name__ == '__main__': # pragma: no cover
with suppress(SystemExit):
unittest.main()
diff --git a/tests/tools_tests.py b/tests/tools_tests.py
index 66a6bb5..8069880 100644
--- a/tests/tools_tests.py
+++ b/tests/tools_tests.py
@@ -9,7 +9,8 @@
import subprocess
import tempfile
import unittest
-from collections import OrderedDict
+
+from collections import Counter, OrderedDict
from collections.abc import Mapping
from contextlib import suppress
from importlib import import_module
@@ -18,6 +19,7 @@
from pywikibot.tools import (
classproperty,
has_module,
+ intersect_generators,
is_ip_address,
suppress_warnings,
)
@@ -731,6 +733,58 @@
self.assertEqual(Foo.bar, Foo._bar)
+class GeneratorIntersectTestCase(TestCase):
+
+ """Base class for intersect_generators test cases."""
+
+ def assertEqualItertools(self, gens):
+ """Assert intersect_generators result is same as set intersection."""
+ # If they are a generator, we need to convert to a list
+ # first otherwise the generator is empty the second time.
+ datasets = [list(gen) for gen in gens]
+ set_result = set(datasets[0]).intersection(*datasets[1:])
+ result = list(intersect_generators(*datasets))
+
+ self.assertCountEqual(set(result), result)
+ self.assertCountEqual(result, set_result)
+
+ def assertEqualItertoolsWithDuplicates(self, gens):
+ """Assert intersect_generators result equals Counter intersection."""
+ # If they are a generator, we need to convert to a list
+ # first otherwise the generator is empty the second time.
+ datasets = [list(gen) for gen in gens]
+ counter_result = Counter(datasets[0])
+ for dataset in datasets[1:]:
+ counter_result = counter_result & Counter(dataset)
+ counter_result = list(counter_result.elements())
+ result = list(intersect_generators(*datasets, allow_duplicates=True))
+ self.assertCountEqual(counter_result, result)
+
+
+class BasicGeneratorIntersectTestCase(GeneratorIntersectTestCase):
+
+ """Disconnected intersect_generators test cases."""
+
+ net = False
+
+ def test_intersect_basic(self):
+ """Test basic intersect without duplicates."""
+ self.assertEqualItertools(['abc', 'db', 'ba'])
+
+ def test_intersect_with_dups(self):
+ """Test basic intersect with duplicates."""
+ self.assertEqualItertools(['aabc', 'dddb', 'baa'])
+
+ def test_intersect_with_accepted_dups(self):
+ """Test intersect with duplicates accepted."""
+ self.assertEqualItertoolsWithDuplicates(['abc', 'db', 'ba'])
+ self.assertEqualItertoolsWithDuplicates(['aabc', 'dddb', 'baa'])
+ self.assertEqualItertoolsWithDuplicates(['abb', 'bb'])
+ self.assertEqualItertoolsWithDuplicates(['bb', 'abb'])
+ self.assertEqualItertoolsWithDuplicates(['abbcd', 'abcba'])
+ self.assertEqualItertoolsWithDuplicates(['abcba', 'abbcd'])
+
+
class TestMergeGenerator(TestCase):
"""Test merging generators."""
To view, visit change 608153. To unsubscribe, or for help writing mail filters, visit settings.