User:Surjection/categorylister2.py

import json
import operator
import sys
import urllib.parse
import urllib.request
from collections import Counter, OrderedDict
from functools import reduce


# Set to True to make MWAPI.request print every API request URL to stderr.
_DEBUG = False


class APIURL():
    """A MediaWiki Action API endpoint (``/w/api.php``) plus query parameters.

    Parameters keep their insertion order, so ``make()`` produces a stable
    query string.
    """

    def __init__(self, domain, **params):
        """Create a URL for ``https://<domain>/w/api.php``.

        :param domain: host name, e.g. ``'en.wiktionary.org'``.
        :param params: query parameters; non-string values are converted
            with ``str()`` when the URL is built.
        """
        self.domain = domain
        self.params = OrderedDict(params)

    def make(self):
        """Return the full request URL as a string.

        Keys and values are percent-encoded with ``urllib.parse.quote``.
        ``/`` is left unescaped, which matches ``quote``'s default ``safe``.
        Non-string values (e.g. ``cmlimit=100``) are stringified instead of
        raising ``TypeError``.
        """
        parstring = urllib.parse.urlencode(
            self.params, safe='/', quote_via=urllib.parse.quote)
        return f'https://{self.domain}/w/api.php?{parstring}'

    def copy(self):
        """Return an independent copy whose ``params`` can be mutated safely."""
        return APIURL(self.domain, **self.params)


class SiteNameParser():
    """Translate short wiki identifiers such as ``enwikt`` into host names."""

    # Identifier suffix -> domain suffix that replaces it.
    suffixes = {'wiki': '.wikipedia.org', 'wikt': '.wiktionary.org'}

    def parse(self, site):
        """Return the host for *site*, e.g. ``'enwikt'`` -> ``'en.wiktionary.org'``.

        :raises ValueError: if *site* has no recognized suffix, or nothing
            precedes the suffix (a bare ``'wiki'`` would otherwise produce
            the invalid host ``'.wikipedia.org'``).
        """
        for suffix, domain_suffix in self.suffixes.items():
            if site.endswith(suffix):
                prefix = site[:-len(suffix)]
                if not prefix:
                    break
                return prefix + domain_suffix
        raise ValueError(f'unrecognized site name: {site}')


class CategoryNameParser():
    """Turn ``[@]site|category`` arguments into category-listing request specs.

    A spec is a dict ``{'url': APIURL, 'deep': bool}``.
    """

    def __init__(self):
        self.sitenameparser = SiteNameParser()

    @staticmethod
    def _members_url(domain, title):
        """Build a ``list=categorymembers`` query for page *title* on *domain*."""
        return APIURL(domain, format='json', action='query',
                      list='categorymembers', cmlimit='100',
                      cmtitle=title)

    def parse(self, cat):
        """Parse ``[@]site|category`` into a request spec.

        A leading ``@`` on the site part requests a deep (recursive) listing.
        ``Category:`` is prepended to the category name.

        :raises ValueError: if *cat* has no ``|``, the category name is
            blank, or the site name is not recognized.
        """
        if '|' not in cat:
            raise ValueError('must be in format site|category')
        site, category = cat.split('|', 1)
        if not category.strip():
            # Otherwise the API rejects the bare 'Category:' title, and the
            # failure only shows up later as a missing 'query' key.
            raise ValueError('category name must not be empty')
        deep = site.startswith('@')
        if deep:
            site = site[1:]
        domain = self.sitenameparser.parse(site)
        return {'url': self._members_url(domain, 'Category:' + category),
                'deep': deep}

    def subcategory(self, url, category):
        """Return a deep request spec for subcategory *category* on *url*'s wiki.

        *category* is a full page title (it already carries its namespace
        prefix); only ``url.domain`` is used from *url*.
        """
        return {'url': self._members_url(url.domain, category), 'deep': True}


class MWAPI():
    """Minimal client for the MediaWiki Action API."""

    # Sent with every request. Wikimedia's User-Agent policy asks clients to
    # identify themselves rather than use urllib's generic default; override
    # this (ideally with contact info) if needed.
    user_agent = 'categorylister2.py (Python urllib)'

    def request(self, url):
        """Fetch *url* (an ``APIURL``) and return the decoded JSON response."""
        target = url.make()
        if _DEBUG:
            print("Making API request to", target, file=sys.stderr)
        req = urllib.request.Request(
            target, headers={'User-Agent': self.user_agent})
        with urllib.request.urlopen(req) as resp:
            result = json.loads(resp.read().decode('utf-8'))
        return result

    def categorymembers(self, url):
        """Yield every ``categorymembers`` entry for *url*, following continuation.

        Continuation parameters are written to a private copy, so the
        caller's *url* is left unmodified.

        :raises RuntimeError: if a response carries an ``error`` object
            instead of ``query`` results.
        """
        url = url.copy()
        while True:
            result = self.request(url)
            if 'error' in result:
                error = result['error']
                raise RuntimeError(
                    f"API error {error.get('code')}: {error.get('info')}")
            yield from result['query']['categorymembers']
            cont = result.get('continue')
            if not cont or 'cmcontinue' not in cont:
                break
            # The continuation protocol expects every key of 'continue'
            # to be sent back with the next request, not just cmcontinue.
            url.params.update(cont)


class StdinLister():
    """Read a page list, one title per line, from standard input."""

    def collect(self, stream=None):
        """Return the stripped, non-blank lines of *stream* (default ``sys.stdin``).

        Surrounding whitespace is removed and blank lines are skipped. Page
        titles cannot be empty or padded, so such lines would only add
        spurious entries (e.g. ``''``) to the combined result.
        """
        if stream is None:
            stream = sys.stdin
        return [title for title in (line.strip() for line in stream) if title]


class CategoryLister():
    """List the page titles of one category, or read them from stdin for ``-``."""

    def __init__(self):
        self.catparser = CategoryNameParser()
        self.mwapi = MWAPI()

    def collect_sub(self, url, deep, include_cats=False, ns0=False, visited=None):
        """Return the member titles of the category queried by *url*.

        :param url: ``APIURL`` for a ``list=categorymembers`` query.
        :param deep: if true, also descend into subcategories (namespace 14).
        :param include_cats: keep subcategory titles in the result.
        :param ns0: keep only namespace-0 pages. This also drops subcategory
            titles, even when *include_cats* is set.
        :param visited: category titles already expanded in this traversal.
            It is shared across recursive calls, so a category cycle cannot
            recurse forever and a subcategory reachable by several paths is
            fetched only once.
        """
        if visited is None:
            visited = set()
        visited.add(url.params.get('cmtitle'))
        pages = []
        for page in self.mwapi.categorymembers(url):
            ns, title = page['ns'], page['title']
            if (ns == 0 or not ns0) and (ns != 14 or include_cats):
                pages.append(title)
            if deep and ns == 14 and title not in visited:
                suburl = self.catparser.subcategory(url, title)['url']
                pages += self.collect_sub(suburl, True, include_cats, ns0, visited)
        return pages

    def collect(self, category, include_cats=False, ns0=False):
        """Return the titles for one ``[@]site|category`` argument, or stdin lines for ``-``."""
        if category == '-':
            return StdinLister().collect()
        data = self.catparser.parse(category)
        url, deep = data['url'], data['deep']
        return self.collect_sub(url, deep, include_cats, ns0)


class MultiCategoryLister():
    """Combine the page lists of several categories with one set operation."""

    def __init__(self, operation):
        """:param operation: callable mapping a list of sets to a set
        (e.g. ``set_union``)."""
        self.lister = CategoryLister()
        self.operation = operation

    def collect(self, categories, include_cats=False, ns0=False):
        """Return the sorted titles produced by applying the operation.

        Each entry of *categories* is collected with ``CategoryLister.collect``
        and turned into a set. The sets are passed to the operation in the
        same order as *categories*.
        """
        sets = [set(self.lister.collect(category, include_cats, ns0))
                for category in categories]
        return sorted(self.operation(sets))


def set_union(sets):
    """Return the items that appear in any of *sets*.

    An empty *sets* gives an empty set rather than raising ``TypeError``.
    """
    return set().union(*sets)


def set_intersection(sets):
    """Return the items shared by every set in *sets* (at least one required)."""
    return reduce(lambda common, current: common & current, sets)


def set_difference(sets):
    """Return the items of the first set that appear in none of the later ones."""
    return reduce(lambda remaining, other: remaining - other, sets)


def set_pairwise_intersection(sets):
    """Return the items that appear in at least two of *sets*.

    Every item is counted once per input it belongs to, so the cost is
    linear in the total input size. An empty *sets* gives an empty set.
    """
    # set(members) makes membership count once per input, even for non-set input.
    counts = Counter(item for members in sets for item in set(members))
    return {item for item, count in counts.items() if count > 1}


def set_symmetric_difference(sets):
    """Return the items that appear in an odd number of *sets*.

    Folding starts from the empty set (xor's identity), so an empty *sets*
    gives an empty set instead of raising ``TypeError``.
    """
    return reduce(operator.xor, sets, set())


def build_parser():
    """Build the command-line argument parser.

    ``main`` requires exactly one mode flag, except that it defaults to
    ``--union`` when a single category is given.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--union', help='all pages in any of the given categories', action='store_true')
    parser.add_argument(
        '--intersection', help='all pages in all of the given categories', action='store_true')
    parser.add_argument(
        '--pairwise-intersection', help='all pages in at least two of the given categories', action='store_true')
    parser.add_argument(
        '--difference', help='all pages in only the first of the given categories', action='store_true')
    parser.add_argument(
        '--symmetric-difference', help='all pages in only an odd number of the given categories', action='store_true')
    parser.add_argument(
        '--cats', help='include categories in list', action='store_true')
    parser.add_argument(
        '--ns0', help='only consider pages in namespace 0 (main namespace)', action='store_true')
    # const=0: without it a bare "--limit" stores None, and the later
    # "args.limit > 0" comparison raises TypeError.
    parser.add_argument(
        '--limit', help='limit final result to first N pages (0 for all)', nargs='?', default=0, const=0, type=int)
    parser.add_argument(
        '--output', help='File name, If not specified, goes to stdout', nargs='?', default=None)
    parser.add_argument('category', nargs='+',
                        help='In the format site|categoryname, such as "enwikt|English lemmas" (prefix with @ to use deep search); use - for stdin')
    return parser


def _print_pages(pages, stream):
    """Write each page title to *stream*, one per line."""
    for page in pages:
        print(page, file=stream)


def main(argv=None):
    """Run the tool: collect pages, combine them, and print the result.

    :param argv: arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    modes = [('union', set_union), ('intersection', set_intersection),
             ('pairwise_intersection', set_pairwise_intersection), ('difference', set_difference),
             ('symmetric_difference', set_symmetric_difference)]
    modeflags = [getattr(args, name) for name, _ in modes]
    if modeflags.count(True) != 1:
        if len(args.category) > 1 or modeflags.count(True) > 1:
            parser.print_help(sys.stderr)
            parser.error('must specify exactly one mode')
        # --union by default if only one category
        operation = set_union
    else:
        operation = next(func for name, func in modes if getattr(args, name))
    results = MultiCategoryLister(operation).collect(args.category, args.cats, args.ns0)
    if args.limit > 0:
        results = results[:args.limit]
    if args.output is None:
        _print_pages(results, sys.stdout)
    else:
        # Use a context manager so the file is flushed and closed deterministically.
        with open(args.output, 'w', encoding='utf-8') as output:
            _print_pages(results, output)
    print(f'Total: {len(results)}', file=sys.stderr)


if __name__ == '__main__':
    main()