import json
import operator
import sys
import urllib.parse
import urllib.request
from collections import Counter
from collections import OrderedDict
from functools import reduce
# When True, MWAPI.request prints each API request URL to stderr before
# fetching it; leave False for quiet operation.
_DEBUG = False
class APIURL():
    """A MediaWiki Action API request: a wiki domain plus ordered query parameters."""

    def __init__(self, domain, **params):
        self.domain = domain
        # Keyword order is preserved so the generated URL is deterministic.
        self.params = OrderedDict(params)

    def make(self):
        """Return the full ``https://<domain>/w/api.php?...`` request URL.

        Both keys and values are percent-encoded with ``urllib.parse.quote``
        (``/`` is left unescaped, matching ``quote``'s default); non-string
        values such as integers are converted with ``str`` first.
        """
        query = urllib.parse.urlencode(self.params, safe='/',
                                       quote_via=urllib.parse.quote)
        return f'https://{self.domain}/w/api.php?{query}'

    def copy(self):
        """Return an independent copy whose params can change without affecting this one."""
        return type(self)(self.domain, **self.params)
class SiteNameParser():
    """Map short wiki site names such as ``enwiki`` or ``frwikt`` to API domains."""

    # Site-name suffix -> domain suffix appended after the language prefix.
    suffixes = {'wiki': '.wikipedia.org', 'wikt': '.wiktionary.org'}

    def parse(self, site):
        """Return the API domain for *site*, e.g. ``'enwiki'`` -> ``'en.wikipedia.org'``.

        Raises:
            ValueError: if *site* has no known suffix or has nothing before it
                (a bare ``'wiki'`` would otherwise yield ``'.wikipedia.org'``).
        """
        for suffix, domain_suffix in self.suffixes.items():
            prefix = site[:-len(suffix)]
            if site.endswith(suffix) and prefix:
                return prefix + domain_suffix
        raise ValueError(f'unrecognized site name: {site}')
class CategoryNameParser():
    """Turn ``site|category`` specifiers into category-member API queries."""

    def __init__(self):
        self.sitenameparser = SiteNameParser()

    @staticmethod
    def _members_url(domain, cmtitle):
        """Build the list=categorymembers query (100 per page) for page *cmtitle*."""
        return APIURL(domain, format='json', action='query',
                      list='categorymembers', cmlimit='100',
                      cmtitle=cmtitle)

    def parse(self, cat):
        """Parse ``[@]site|category`` into ``{'url': APIURL, 'deep': bool}``.

        A leading ``@`` on the site part requests a deep (recursive) search.
        ``'Category:'`` is prepended to the category name.

        Raises:
            ValueError: if there is no ``|``, the category name is empty, or
                the site name is not recognized.
        """
        if '|' not in cat:
            raise ValueError('must be in format site|category')
        site, category = cat.split('|', 1)
        if not category.strip():
            raise ValueError('category name must not be empty')
        deep = site.startswith('@')
        if deep:
            site = site[1:]
        domain = self.sitenameparser.parse(site)
        return {'url': self._members_url(domain, 'Category:' + category),
                'deep': deep}

    def subcategory(self, url, category):
        """Build a deep query for subcategory *category* (a full page title) on url's wiki."""
        return {'url': self._members_url(url.domain, category),
                'deep': True}
class MWAPI():
    """Minimal client for the MediaWiki Action API."""

    # Wikimedia asks API clients to send an identifying User-Agent; override
    # this (ideally adding contact details) if requests get rejected.
    user_agent = 'mw-category-lister/1.0 (Python urllib)'
    # Seconds to wait for a response, so a stalled connection fails instead
    # of hanging forever.
    timeout = 60

    def request(self, url):
        """Fetch *url* (an APIURL) and return the decoded JSON response.

        Network and HTTP errors propagate from ``urllib.request.urlopen``.
        """
        target = url.make()
        if _DEBUG:
            print("Making API request to", target, file=sys.stderr)
        req = urllib.request.Request(target, headers={'User-Agent': self.user_agent})
        with urllib.request.urlopen(req, timeout=self.timeout) as resp:
            result = json.loads(resp.read().decode('utf-8'))
        return result

    def categorymembers(self, url):
        """Yield every member record of the category queried by *url*.

        Follows API continuation by merging the response's ``continue`` values
        into the next request. Continuation is applied to a copy, so the
        caller's *url* is left untouched.

        Raises:
            RuntimeError: if the API response reports an error.
        """
        url = url.copy()
        while True:
            result = self.request(url)
            if 'error' in result:
                err = result['error']
                raise RuntimeError(
                    f"MediaWiki API error: {err.get('code')}: {err.get('info')}")
            yield from result.get('query', {}).get('categorymembers', [])
            cont = result.get('continue', {})
            if 'cmcontinue' not in cont:
                break
            url.params.update(cont)
class StdinLister():
    """Read page titles, one per line, from standard input (or another stream)."""

    def collect(self, stream=None):
        """Return the non-blank lines of *stream*, with surrounding whitespace removed.

        Args:
            stream: text stream to read; defaults to ``sys.stdin`` (looked up at
                call time so redirections are honoured).

        Blank lines are skipped so they do not turn into empty page titles.
        """
        if stream is None:
            stream = sys.stdin
        return [title for title in (line.strip() for line in stream) if title]
class CategoryLister():
    """List the titles of a category's members, or of lines read from stdin."""

    def __init__(self):
        self.catparser = CategoryNameParser()
        self.mwapi = MWAPI()

    def collect_sub(self, url, deep, include_cats=False, ns0=False):
        """Return member titles of the category queried by *url*.

        Args:
            url: category-members query, as built by CategoryNameParser.
            deep: also descend into subcategories (namespace 14).
            include_cats: keep subcategory titles in the result.
            ns0: keep only namespace-0 pages.

        Each subcategory is descended into at most once, so category cycles
        terminate instead of recursing forever. A page reachable through
        several subcategories may still appear more than once.
        """
        visited = {url.params.get('cmtitle')}
        return self._collect_sub(url, deep, include_cats, ns0, visited)

    def _collect_sub(self, url, deep, include_cats, ns0, visited):
        """Recursive worker for collect_sub; *visited* holds category titles already walked."""
        pages = []
        for page in self.mwapi.categorymembers(url):
            ns, title = page['ns'], page['title']
            if (ns == 0 or not ns0) and (ns != 14 or include_cats):
                pages.append(title)
            if deep and ns == 14 and title not in visited:
                visited.add(title)
                suburl = self.catparser.subcategory(url, title)['url']
                pages += self._collect_sub(suburl, True, include_cats, ns0, visited)
        return pages

    def collect(self, category, include_cats=False, ns0=False):
        """Return titles for a ``[@]site|category`` specifier, or stdin lines for ``-``."""
        if category == '-':
            return StdinLister().collect()
        data = self.catparser.parse(category)
        return self.collect_sub(data['url'], data['deep'], include_cats, ns0)
class MultiCategoryLister():
    """Combine the member lists of several categories with a set operation."""

    def __init__(self, operation):
        self.lister = CategoryLister()
        # Callable taking a list of sets (one per category) and returning a set.
        self.operation = operation

    def collect(self, categories, include_cats=False, ns0=False):
        """Return the sorted result of ``operation`` over each category's member set.

        A specifier repeated in *categories* is fetched only once. This also means
        ``-`` (stdin) is read once instead of yielding an empty set the second
        time. Each position still receives its own set, so operations like
        symmetric difference see the repetition.
        """
        fetched = {}
        sets = []
        for category in categories:
            if category not in fetched:
                fetched[category] = set(self.lister.collect(category, include_cats, ns0))
            # Copy so an operation that mutates its inputs cannot corrupt the cache.
            sets.append(set(fetched[category]))
        return sorted(self.operation(sets))
def set_union(sets):
    """Return a new set of every item found in any of *sets*.

    An empty *sets* yields an empty set. The inputs are never returned or modified.
    """
    return set().union(*sets)
def set_intersection(sets):
    """Fold *sets* left to right with ``&``, keeping only items present in every set.

    Raises TypeError (from ``reduce``) when *sets* is empty.
    """
    return reduce(lambda common, members: common & members, sets)
def set_difference(sets):
    """Return the items of the first set that appear in none of the later sets.

    Implemented as a left fold with ``-``; raises TypeError when *sets* is empty.
    """
    return reduce(lambda remaining, other: remaining - other, sets)
def set_pairwise_intersection(sets):
    """Return the items that occur in at least two of the given sets.

    Each element of *sets* must be a set, so an item is counted at most once
    per input. An empty or single-set input yields an empty set.
    """
    counts = Counter(item for members in sets for item in members)
    return {item for item, count in counts.items() if count > 1}
def set_symmetric_difference(sets):
    """Return the items present in an odd number of *sets* (left fold with ``^``).

    Raises TypeError when *sets* is empty.
    """
    return reduce(lambda odd, members: odd ^ members, sets)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--union', help='all pages in any of the given categories', action='store_true')
    parser.add_argument(
        '--intersection', help='all pages in all of the given categories', action='store_true')
    parser.add_argument(
        '--pairwise-intersection', help='all pages in at least two of the given categories', action='store_true')
    parser.add_argument(
        '--difference', help='all pages in only the first of the given categories', action='store_true')
    parser.add_argument(
        '--symmetric-difference', help='all pages in only an odd number of the given categories', action='store_true')
    parser.add_argument(
        '--cats', help='include categories in list', action='store_true')
    parser.add_argument(
        '--ns0', help='only consider pages in namespace 0 (main namespace)', action='store_true')
    # const=0: a bare "--limit" means "no limit" instead of yielding None,
    # which would crash the "args.limit > 0" comparison below.
    parser.add_argument(
        '--limit', help='limit final result to first N pages (0 for all)', nargs='?', default=0, const=0, type=int)
    parser.add_argument(
        '--output', help='File name, If not specified, goes to stdout', nargs='?', default=None)
    parser.add_argument('category', nargs='+',
                        help='In the format site|categoryname, such as "enwikt|English lemmas" (prefix with @ to use deep search); use - for stdin')
    args = parser.parse_args()
    # (argparse attribute name, set operation); argparse maps '-' in flags to '_'.
    modes = [('union', set_union), ('intersection', set_intersection),
             ('pairwise_intersection', set_pairwise_intersection), ('difference', set_difference),
             ('symmetric_difference', set_symmetric_difference)]
    modeflags = [getattr(args, name) for name, func in modes]
    if modeflags.count(True) != 1:
        if len(args.category) > 1 or modeflags.count(True) > 1:
            parser.print_help(sys.stderr)
            parser.error('must specify exactly one mode')
        # --union by default if only one category
        operation = set_union
    else:
        operation = next(func for name, func in modes if getattr(args, name))
    results = MultiCategoryLister(operation).collect(args.category, args.cats, args.ns0)
    if args.limit > 0:
        results = results[:args.limit]
    output = sys.stdout if args.output is None else open(
        args.output, 'w', encoding='utf-8')
    try:
        for page in results:
            print(page, file=output)
    finally:
        # Close only a file opened here; never close sys.stdout.
        if output is not sys.stdout:
            output.close()
    print(f'Total: {len(results)}', file=sys.stderr)