遍历给定模块中给定类的子类

时间:2020-03-05 18:47:47  来源:igfitidea点击:

在Python中,给定模块X和类Y,我如何迭代或者生成模块X中存在的所有Y子类的列表?

解决方案

回答

这是一种实现方法:

import inspect

def get_subclasses(mod, cls):
    """Yield the classes in module ``mod`` that inherit from ``cls``"""
    for name, obj in inspect.getmembers(mod):
        if hasattr(obj, "__bases__") and cls in obj.__bases__:
            yield obj

回答

给定模块foo.py

class foo(object): pass
class bar(foo): pass
class baz(foo): pass

class grar(Exception): pass

def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)

        try: 
             if issubclass(o, clazz):
             yield name, o
        except TypeError: pass

>>> import foo
>>> list(foo.find_subclasses(foo, foo.foo))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>)]
>>> list(foo.find_subclasses(foo, object))
[('bar', <class 'foo.bar'>), ('baz', <class 'foo.baz'>), ('foo', <class 'foo.foo'>), ('grar', <class 'foo.grar'>)]
>>> list(foo.find_subclasses(foo, Exception))
[('grar', <class 'foo.grar'>)]

回答

我可以建议Chris AtLee和zacherates的答案都不能满足要求吗?
我认为这种修改zacerates的答案更好:

def find_subclasses(module, clazz):
    for name in dir(module):
        o = getattr(module, name)
        try:
            if (o != clazz) and issubclass(o, clazz):
                yield name, o
        except TypeError: pass

我不同意给定答案的原因是,第一个不产生属于给定类的遥远子类的类,第二个不包含给定类。

回答

尽管Quamrana的建议效果很好,但我还是建议对它进行一些可能的改进,使其更符合Python风格。他们依靠使用标准库中的inspect模块。

  • 我们可以使用inspect.getmembers()来避免getattr调用
  • 使用inspect.isclass()可以避免try / catch

通过这些,我们可以根据需要将整个事情简化为单个列表理解:

def find_subclasses(module, clazz):
    return [
        cls
            for name, cls in inspect.getmembers(module)
                if inspect.isclass(cls) and issubclass(cls, clazz)
    ]