理解 Python super

今天在知乎回答了一个问题,居然一个赞都没有,也是神奇,毕竟这算是我非常认真答的题之一。既然如此就贴过来好了,有些内容之后再补充。

原问题

Python中既然可以直接通过父类名调用父类方法为什么还会存在super函数?
比如
class Child(Parent):
    def __init(self):
    Parent.__init
(self)
这种方式与super(Child, self).init()有区别么?

回答
针对你的问题,答案是可以,并没有区别。但是这题下的回答我感觉都不够好。

要谈论 super,首先我们应该无视 "super" 这个名字带给我们的干扰。

不要一说到 super 就想到父类!super 指的是 MRO 中的下一个类!
不要一说到 super 就想到父类!super 指的是 MRO 中的下一个类!
不要一说到 super 就想到父类!super 指的是 MRO 中的下一个类!

一说到 super 就想到父类这是初学者很容易犯的一个错误,也是我当年犯的错误。 忘记了这件事之后,再去看这篇文章:Python’s super() considered super! 这是 Raymond Hettinger 写的一篇文章,也是全世界公认的对 super 讲解最透彻的一篇文章,凡是讨论 super 都一定会提到它(当然还有一篇 Python's Super Considered Harmful)。

如果不想看长篇大论就去看这个答案super 其实干的是这件事:

def super(cls, inst):
    mro = inst.__class__.mro()
    return mro[mro.index(cls) + 1]

两个参数 cls 和 inst 分别做了两件事:
1. inst 负责生成 MRO 的 list
2. 通过 cls 定位当前 MRO 中的 index, 并返回 mro[index + 1]
这两件事才是 super 的实质,一定要记住!
MRO 全称 Method Resolution Order,它代表了类继承的顺序。后面详细说。

举个例子

class Root(object):
    def __init__(self):
        print("this is Root")

class B(Root):
    def __init__(self):
        print("enter B")
        # print(self)  # this will print <__main__.D object at 0x...>
        super(B, self).__init__()
        print("leave B")

class C(Root):
    def __init__(self):
        print("enter C")
        super(C, self).__init__()
        print("leave C")

class D(B, C):
    pass

d = D()
print(d.__class__.__mro__)

输出

enter B
enter C
this is Root
leave C
leave B
(<class '__main__.D'>, <class '__main__.B'>, <class '__main__.C'>, <class '__main__.Root'>, <type 'object'>)

知道了 super 和父类其实没有实质关联之后,我们就不难理解为什么 enter B 下一句是 enter C 而不是 this is Root(如果认为 super 代表“调用父类的方法”,会想当然的认为下一句应该是this is Root)。流程如下,在 B 的 __init__ 函数中:

super(B, self).__init__()

首先,我们获取 self.__class__.__mro__,注意这里的 self 是 D 的 instance 而不是 B 的

(<class '__main__.D'>, <class '__main__.B'>, <class '__main__.C'>, <class '__main__.Root'>, <type 'object'>)

然后,通过 B 来定位 MRO 中的 index,并找到下一个。显然 B 的下一个是 C。于是,我们调用 C 的 __init__,打出 enter C。

顺便说一句为什么 B 的 __init__ 会被调用:因为 D 没有定义 __init__,所以会在 MRO 中找下一个类,去查看它有没有定义 __init__,也就是去调用 B 的 __init__

其实这一切逻辑还是很清晰的,关键是理解 super 到底做了什么。

于是,MRO 中类的顺序到底是怎么排的呢?Python’s super() considered super!中已经有很好的解释,我翻译一下:
在 MRO 中,基类永远出现在派生类后面,如果有多个基类,基类的相对顺序保持不变。
关于 MRO 的官方文档参见:The Python 2.3 Method Resolution Order,有一些关于 MRO 顺序的理论上的解释。

最后的最后,提醒大家. 什么 super 啊,MRO 啊,都是针对 new-style class。如果不是 new-style class,就老老实实用父类的类名去调用函数吧。

Useful Hack:Lazy module attribute

昨天晚上师兄在 qq 上和我诉苦,说我们的代码测试起来太不方便了。问题大概出在这段代码:

# motorclient.py
import motor
from settings import mongo_machines, REPLICASET_NAME
from pymongo import ReadPreference

_motorclient = motor.MotorReplicaSetClient(
    ','.join(mongo_machines),
    replicaSet=REPLICASET_NAME,
    readPreference=ReadPreference.NEAREST)

fbt = _motorclient.fbt
reward = _motorclient.fbt_reward
fbt_log = _motorclient.fbt_log

这是我之前写的对 motor 的简单封装,无关代码已经去掉。目的很简单,他们每次要访问数据库只要先 import motorclient,然后用 motorclient.dbname 操作各个数据库就行了。问题在哪里呢?

发现 motorclient 好蛋疼,只要一 import 就必须连接数据库,本地测试每次都要打 mock

这是他的原话,他想在没有配置副本集的本机进行测试,然而 motorclient 只要一 import 就会初始化一个 MotorReplicaSetClient,于是只能 Mock。于是现在有如下需求:

  1. 希望保证现有接口不变
  2. 不要一 import 就初始化
  3. 能非常容易地变成只连接本地数据库而不是副本集

怎么做呢?

我突然想到,David Beazley 的演讲里好像提到了这个概念(关于他的演讲请参考 PyCon2015 笔记)。在 slide 的 150-152 页,他当时想实现的是,import 某个 package 的时候,不直接把 submodules/subpackage 都给 import 进来(因为很耗时间,相当于把所有文件执行一次),而是按需 import,他把这个技巧叫 "Lazy Module Assembly"。我面临的需求和他类似,也要用 "lazy" 的方式加载,只不过针对的是一个 module 里的变量。

上网搜了搜,参考了 SO 上的某答案,完成了 lazy 版,我把它叫做 lazy module attribute

# motorclient.py
mode = None
def set_mode(m):
    global mode
    mode = m

class Wrapper:

    localhost = '127.0.0.1'
    port = 27017
    dbs = {
        'fbt': None,
        'reward': None,
        'fbt_log': None
    }

    def __init__(self, module):
        self.module = module
        self._motorclient = None

    def __getattr__(self, item):
        if item in self.dbs and self._motorclient is None:
            if mode == 'test':
                self._motorclient = motor.MotorClient(host=self.localhost,
                                                      port=self.port)
            else:
                self._motorclient = motor.MotorReplicaSetClient(
                    ','.join(mongo_machines),
                    replicaSet=REPLICASET_NAME,
                    readPreference=ReadPreference.NEAREST)

            self.dbs['fbt'] = self._motorclient.fbt
            self.dbs['reward'] = self._motorclient.fbt_reward
            self.dbs['fbt_log'] = self._motorclient.fbt_log
            self.module.__dict__.update(self.dbs)

        return getattr(self.module, item)

sys.modules[__name__] = Wrapper(sys.modules[__name__])

它的工作流程是这样:
1. 在 import motorclient,会创建一个 Wrapper 类的实例替换掉这个 module 本身,并且把原来的 module object 赋给 self.module别的什么都不做
2. 然后我们在别的文件中访问 motorclient.fbt,进入 Wrapper 实例的 __getattr__ 方法,item='fbt'。因为是初次访问,self._motorclient is None 的条件满足,这时开始初始化变量。
3. 根据全局变量 mode 的值,我们会创建 MotorClient 或是 MotorReplicaSetClient,然后把那几个数据库变量也赋值,并且更新 self.module.__dict__.update(self.dbs)。这个效果就和我们的老版本初始化完全一样了,相当于直接把变量定义写在文件里。
4. 然后调用 getattr(self.module, item),因为我们已经更新过 self.module__dict__,所以能够正常返回属性值。第一次访问至此结束
5. OK,下一次再访问 motorclient.fbt,因为 self._motorclient 已经有值了,所以我们就不再初始化,直接把活交给 self.module 就好了。


下面的内容比较 internal,看不下去的同学就不要看了。。。

本来说到这里就差不多了,不过对 Python import 机制比较了解的同学可能会看出代码中的一个潜在问题。就是这句话: python sys.modules[__name__] = Wrapper(sys.modules[__name__])

为什么 sys.modules 的 key 一定就是 __name__ 呢?下面将追根溯源,证明这一点。

据 Brett Cannon 在《How Import Works》演讲的 slide 第 27 页对 load_module 函数的描述,sys.modules 所用的 key 是 fullname

如果标准库中有类似 module.__name__ = fullname 这种东西,那么我们可以断定 __name__ 就是 fullname。于是苦逼地翻了半天源码,好在终于找到了,有这么一句话python module.__name__ = spec.name

那么这个 Spec 又是什么呢?它实际上是 Python3.4 里才引入的一个类,官方的描述是 "A specification for a module's import-system-related state"。好,就差最后一步了!我又找到了一句代码

spec = spec_from_loader(fullname, self)

没错,Spec 初始化的第一个参数,传入的是 fullname,而它恰恰被赋给了 spec.name。至此,我们的推理终于完成,现在可以肯定地说,sys.modules 的 key 就是这个 module 的 __name__。(实际上这个证明针对的是 Python3.4,不过这种接口肯定是向前兼容的,对于所有版本都成立)

Q.E.D.

当初写 ezcf 的时候其实就遇到过这个问题,只不过没深究。我的代码中有这么一段:

class BaseLoader(_BaseClass):

    def __init__(self, *args, **kwargs):
        pass

    def load_module(self, fullname):
        if fullname in sys.modules:
            mod = sys.modules[fullname]
        else:
            mod = sys.modules.setdefault(fullname, imp.new_module(fullname))

        mod.__file__ = self.cfg_file
        mod.__name__ = fullname  # notice this !!
        mod.__loader__ = self
        mod.__package__ = '.'.join(fullname.split('.')[:-1])

        return mod

当时看到别人都写 self.__name__ = fullname,于是就跟着这么写了,并不明白其中原理。现在终于弄清了,这是 convention,必须得这么写。

不是后记:
感觉好久都没有写文章了啊。。空余时间基本都去刷题了啊。。。跪求 Offer
以及如果读者中有能内推的请联系我(然而并没有什么读者(╥﹏╥)

更好地 partition array

很多算法或者题目里面都有 partition array 这步。所谓 partition array,指的就是把元素按某种条件分开,一部分放前面,另一部分放后面。

比如快速排序,这个条件就是“元素是否小于 pivot”,小于放前面,大于放后面。

大部分人写快排,其中 partition 的一步都喜欢这么写:

def partition(alist, first, last):
    pivot = alist[first]

    left = first + 1
    right = last

    while left <= right:
        while left <= right and alist[left] <= pivote:
            left += 1

        while left <= right and alist[right] >= pivot:
            right += 1

        if left <= right:
            alist[left], alist[right] = alist[right], alist[left]
            left += 1
            right -= 1

    alist[first], alist[right] = alist[right], alist[first]
    return right

弄两根指针 left, right 分别放在 array 的左端和右端,然后 left 右移,right 左移,如果能交换就交换元素,直到 left > right

上面是一个比较标准的 partition 实现,也有一些变体,不过凡是使用 left/right 指针的都是一类。这类实现当然没有问题,但是不优雅,而且难记。

哪里不优雅?仔细看就会发现,代码里有 4 次 left <= right 的判断!而且这 4 次都是必要的。

为什么难记?因为代码多,所以难记!而且稍不注意,最后一步就会写成 alist[first], alist[left] = alist[left], alist[first]

优雅的写法是什么呢?

def partition(alist, first, last):
    pivot = alist[first]        
    i = first

    for j in range(first + 1, last):
        if alist[j] <= pivot:
            i += 1
            alist[i], alist[j] = alist[j], alist[i]

    alist[first], alist[i] = alist[i], alist[first]
    return i

短了这么多,更重要的是再也没有重复的 left <= right 判断了!这段代码也很好理解,j 指针把 array 扫一遍,把 <= pivot 的放到前面,i 用来记录放的地方。停止的时候,i 的位置是 <= pivot 部分的最后一个元素的 index。当然,不要忘记把 pivot 和这个元素交换。

需要注意的地方有两个,首先是初始化:i = first, j = first + 1。还有,每次是先 i += 1,然后再交换i, j元素。

其实这种区别不仅仅在于快排,所有 partition array 都是一样。比如这题:

给出一个整数数组nums和一个整数k。划分数组(即移动数组nums中的元素),使得:

所有小于k的元素移到左边
所有大于等于k的元素移到右边

不优雅的写法这么写:

int partitionArray(vector<int> &nums, int k) {
    int i = 0, j = nums.size() - 1;
    while (i <= j) {
        while (i <= j && nums[i] < k) i++;
        while (i <= j && nums[j] >= k) j--;
        if (i <= j) {
            int temp = nums[i];
            nums[i] = nums[j];
            nums[j] = temp;
            i++;
            j--;
        }
    }
}

还是 4 次比较╮(╯▽╰)╭

优雅的写法:

def partitionArray(nums):        
    i = -1

    for j in range(len(nums)):
        if nums[j] < k:
            i = i + 1
            nums[i], nums[j] = nums[j], nums[i]

当然拿Py和C++比是不合适的,但关键是少了 4 次比较,代码就好看多了。
这一题不存在真实的 pivot,所以我们对快排做了改进,初始化时 i 不再是 0,而是 -1。快排的时候,因为第一次交换时 i = first + 1,所以初始化 i=first 实际上跳过了 pivot 也就是 array[first]。这里因为没有需要跳过的元素,所以 i=-1,然后 j 还是等于 i + 1 也就是 0。其它的没有区别。

熟悉没有真实 pivot 的写法之后,不论按何种规则划分,都只改一行就可以。
比如按照奇偶划分:

def partitionArray(self, nums):
    n = len(nums)
    i = -1

    for j in range(n):
        if nums[j] % 2 == 1:  # changed this line
            i = i + 1
            nums[i], nums[j] = nums[j], nums[i]

其实到这里本来是写完了,但是那天和同学吃饭的时候聊天,又让我意识到优雅写法不光只是优雅而已,有时候只能用它。同学出这题考我:

单链表排序能不能用快排?

我想了想,没有不能用的理由吧,就说能用。他说,不行,因为是单链表,所以不能用快排。我没反应过来,就问为什么单链表不能用快排?他很不解我居然提出这个问题,说你要移动左右两个指针啊,但是单链表你没法把右边的指针往左移动,所以没法用快排。

于是我瞬间懂了,他只知道快排可以用左右指针来写,却不知道另一种写法。原本我只是觉得第二种写法更优雅,没想到在不知道的情况下,居然会误认为单链表没法用快排。

不只是更优雅,而是更好。


top