scaramanga

The hacker with the supernumerary nipple

16 March 2019

SQL-Like Sorting in Python

by Gianni Tedesco

Sorting things in python can often be a pain point. You need to be familiar with the decorate-sort-undecorate paradigm. Also known as the Schwartzian transform. In python3 this technique replaced the old system of comparator functions. Lots of words have been expended discussing this in the quite excellent python documentation so I won’t try and duplicate any of that here.

For what it’s worth, I actually agree with the python implementors that this is the better, less error-prone approach to the problem. And it can have some performance benefits too.

But what if you crave simpler times? Flash back to when Eric Clapton’s cover of “I shot the sheriff” was blaring out of everyone’s 8-tracks. SQL gives us a pretty nice and intuitive formalisation for defining orderings over bags of tuples. And iterable sequences of objects in python can be considered exactly that.

Let’s get the preliminaries out of the way. This is what the use-case looks like:

from types import SimpleNamespace
from enum import Enum
class Order(Enum):
    ASC = 0
    DESC = 1

keyfunc = sqlordering(('x', Order.ASC), ('y', Order.DESC))

input = [
    SimpleNamespace(x=999, y='aardvark'),
    SimpleNamespace(x=999, y='xylophone'),
    SimpleNamespace(x=111, y='xylophone'),
    SimpleNamespace(x=111, y='aardvark'),
    # etc..
]

for item in sorted(input, key=keyfunc):
    print(item)
# SimpleNamespace(x=111, y='xylophone')
# SimpleNamespace(x=111, y='aardvark')
# SimpleNamespace(x=999, y='xylophone')
# SimpleNamespace(x=999, y='aardvark')

The key function

Let’s work backwards from the key function required by list.sort(). The simplest way to produce a total ordering is to define a class whose constructor takes one argument, the item that we want to compare/sort. That class should keep a reference to the original object and then define the six rich comparison methods.

We’ll use the functools.total_ordering decorator to fill in a lot of boilerplate for us. This way we’ll only need to define an __lt__ and an __eq__ method in our class and the decorator will fill in the rest for us. All those other methods can be defined in terms of just those two (e.g. __ne__ is equivalent to not __eq__).

Regardless, it looks like sorted and list.sort use only the __lt__ comparison so those methods ought never be called and there should be no performance hit. But we may as well define them just because there’s no such thing as a guarantee in life. But that’s another topic, and you can read about my failures and disappointments in life in my upcoming memoir.

def sqlordering(*spec):
    @total_ordering
    class K:
        __slots__ = ['_obj']
        __hash__ = None
        def __init__(self, obj):
            self._obj = obj
        def __lt__(self, other):
	    # Implement this
	    pass
        def __eq__(self, other):
	    # Implement this
	    pass
    return K

How to define __lt__ and __eq__

The first thing we need to do is iterate over the sort specification. For any field sorted ASCending we’ll use operator.lt (less-than) and for any field sorted DESCending we’ll use the opposite: operator.ge (greater-than-or-equal).

Then we’ll need to use partial application to combine the attribute lookups with the operator.

For example, to create a function f which takes two objects and returns True if attribute x on the first argument is less than attribute x on the second argument, we can do:

from functools import partial
# a is attribute name, a string
# c is comparison operator, a callable
# x,y will remain free paramaters and can be any type
cmp_func = lambda a,c,x,y:c(getattr(x,a),getattr(y,a))
f = partial(cmp_func, 'x', operator.lt)

Now, with that, we need to create a higher order function which combines this sequence of functions to create the two operators for our total ordering.

For the equals operator it’s straightforward, two items are equal if and only if all attributes are equal. So we just iterate over the sequence of equality functions applying them as we go. If any of them fail, the comparison fails.

For the less-than operator it’s a little bit more involved. Recall the semantics of SQL sorts. You can find that in section 8.2 “<comparison predicate>” of ISO/IEC 9075-2:2016, general rules, paragraph 1, clause h:

The relative position of row P is before row Q if PVn precedes QVn for some n, 1 (one) ≤ n ≤ N, and PVi is not distinct from QVi for all i < n.

In other words, if two rows are equal in the first comparison key, then we proceed to ordering by the second comparison key, and so on.

So with that we can go ahead and complete the implementation. Here it is:

from typing import Iterable, Tuple
from functools import total_ordering, partial
from operator import lt, ge, eq

def _cmp_func(*args):
    spec : Iterable[Tuple[str, Order]] = args
    cmp_func = lambda a,c,x,y:c(getattr(x,a),getattr(y,a))
    for attr, sense in spec:
        c_eq = partial(cmp_func, attr, eq)
        if sense == Order.ASC:
            c_lt = partial(cmp_func, attr, lt)
        elif sense == Order.DESC:
            c_lt = partial(cmp_func, attr, ge)
        else:
            raise ValueError
        yield (c_lt, c_eq)

def sqlordering(*spec):
    funcs = tuple(_cmp_func(*spec))

    @total_ordering
    class K:
        __slots__ = ['_obj']
        __hash__ = None
        def __init__(self, obj):
            self._obj = obj
        def __lt__(self, other):
            for ltcmp, eqcmp in funcs:
                if ltcmp(self._obj, other._obj):
                    return True
                elif not eqcmp(self._obj, other._obj):
                    return False
            return False
        def __eq__(self, other):
            for ltcmp, eqcmp in funcs:
                if not eqcmp(self._obj, other._obj):
                    return False
            return True
    return K
tags: python - sql