Skip to content

Python 日常必备技巧备忘录(下)

发布于  at 11:21 AM

TOC

Open TOC

惰性属性

__getattr__

如果类中定义了 __getattr__ ,那么每当访问该类对象的属性,而且实例字典里又找不到这个属性时,系统就会触发 __getattr__ 方法。

class LazyRecord:

    def __init__(self):
        self.exists = 5

    def __getattr__(self, name):
        value = f'Value for {name}'
        setattr(self, name, value)
        return value

if __name__ == "__main__":
    data = LazyRecord()
    print("Before: ", data.__dict__)
    print("Foo: ", data.foo)
    print("After: ", data.__dict__)

执行结果:

Before:  {'exists': 5}
Foo:  Value for foo
After:  {'exists': 5, 'foo': 'Value for foo'}

如果要实现惰性的(lazy,也指按需的)数据访问机制,而这份数据又没有 schema,那么通过 __getattr__ 来做就相当合适。

__getattribute__

假设我们现在还需要验证数据库系统的事务状态。也就是说,用户每次访问某属性时,我们都要确保数据库里面的那条记录依然有效,而且相应的事务也处在开启状态。

这个需求没办法通过 __getattr__ 实现,因为一旦对象的实例字典里包含了这个属性,那么程序就会直接从字典获取,而不会再触发 __getattr__

应对这种比较高级的用法,Python 的 object 还提供了另一个 hook,叫作 _getattribute__。只要访问对象中的属性,就会触发这个特殊方法,即便这项属性已经在 __dict__ 字典里,系统也还是会执行 __getattribute__ 方法。

于是,我们可以在这个方法里面检测全局的事务状态,这样就能对每一次属性访问操作都进行验证了。

注意: 这种写法开销很大,而且会降低程序的效率,但有的时候确实值得这么做。

如果要访问的属性根本就不应该存在,那么可以在 __getattr__ 方法里面拦截。无论是 __getattr__ 还是 __getattribute__,都应该抛出标准的 AttributeError 来表示属性不存在或不适合存在的情况。

class ValidatingRecord:
    """验证记录类:在每次属性访问时都会触发 __getattribute__"""

    def __init__(self):
        self.exists = 5

    def __getattribute__(self, name):
        # 拦截所有属性访问,包括已存在的属性
        print(f"Called __getattribute__({name!r})")

        try:
            value = super().__getattribute__(name)
            print(f"Found {name!r}, returning {value!r}")
            return value
        except AttributeError:
            # 属性不存在时,动态创建
            value = f"Value for {name}"
            print(f"Setting {name!r} to {value!r}")
            setattr(self, name, value)
            return value

if __name__ == "__main__":
    data = ValidatingRecord()
    print("exists: ", data.exists)  # 访问已存在属性
    print("First Foo: ", data.foo)  # 第一次访问 foo
    print("Second After: ", data.foo)  # 第二次访问 foo,仍会触发 __getattribute__

运行结果:

Called __getattribute__('exists')
Found 'exists', returning 5
exists:  5
Called __getattribute__('foo')
Setting 'foo' to 'Value for foo'
First Foo:  Value for foo
Called __getattribute__('foo')
Found 'foo', returning 'Value for foo'
Second After:  Value for foo

如果要使用本对象的普通属性,那么应该通过 super()(也就是object类)来使用,而不要直接使用,避免无限递归。

需要准确计算浮点数,使用 decimal

Python 语言很擅长操纵各种数值。它的整数类型实际上可以表示任意尺寸的整型数据,它的双精度浮点数类型遵循 IEEE 754 规范。

假如我们要给国际长途电话计费。通话时间用分和秒来表示,这项数据是已知的(例如 3 分 42 秒)​。通话费率也是固定的,例如从美国打给南极洲的电话,每分钟 1.45 美元。现在要计算这次通话的具体费用。

>>> rate = 1.45
>>> seconds = 3*60 +42
>>> cost = rate * seconds / 60
>>> print(cost)
5.364999999999999

这个答案比正确答案(5.365)少了 0.000000000000001。

浮点数必须表示成 IEEE 754 格式,所以采用浮点数算出的结果可能跟实际结果稍有偏差。

这样的计算应该用 Python 内置的 decimal 模块所提供的 Decimal 类来做。这个类默认支持 28 位小数,如果有必要,还可以调得更高。

>>> from decimal import Decimal
>>> rate = Decimal('1.45')
>>> seconds = Decimal(3*60 + 42)
>>> cost = rate * seconds / Decimal(60)
>>> print(cost)
5.365

如果费率较低,例如每分钟 0.05 美元,通话 5 秒钟的总费用。

>>> rate = Decimal('0.05')
>>> seconds = Decimal('5')
>>> small_cost = rate * seconds / Decimal(60)
>>> print(small_cost)
0.004166666666666666666666666667

如果四舍五入,就会变成 0:

>>> print(round(small_cost, 2))
0.00

Decimal 类提供了 quantize 函数,可以根据指定的舍入方式把数值调整到某一位。

>>> from decimal import ROUND_UP
>>> rounded = cost.quantize(Decimal('0.01'), rounding=ROUND_UP)
>>> print(f'Rounded {cost} to {rounded}')
Rounded 5.365 to 5.37

Decimal 可以很好地应对小数点后面 数位有限 的值(也就是定点数,fixed point number)​。但对于小数位无限的值(例如 1/3)来说,依然会有误差。

如果想表示精度不受限的有理数,那么可以考虑使用内置的 fractions 模块里面的 Fraction 类。

deque 实现生产者-消费者队列

先进先出的(first-in, first-out,FIFO)队列,这种队列也叫作 生产者-消费者队列(producer–consumer queue)或生产-消费队列。

由于 list 增加元素时可能需要扩容,从头部删除元素则需要将后续元素前移,性能开销很大(平方增长)。

Python 内置的 collections 模块里面有个 deque 类,可以解决这个问题。这个类所实现的是 双向队列(double-ended queue)​,从头部执行插入或尾部执行删除操作,都只需要固定的时间,所以它很适合充当 FIFO 队列。

from collections import deque
queue = deque()

# 生产者函数
def producer():
    for i in range(5):
        item = f"任务-{i}"
        queue.append(item)
        print(f"生产者:生产了 {item}")

# 消费者函数
def consumer():
    while queue:
        item = queue.popleft()
        print(f"消费者:消费了 {item}")

producer()
consumer()

考虑用 bisect 搜索已排序的序列

Python 内置的 bisect 模块可以更好地搜索 有序列表。其中的 bisect_left 函数,能够迅速地对任何一个有序的序列执行 二分搜索

>>> from bisect import bisect_left
>>> data = list(range(10**5))
>>> index = bisect_left(data, 91234)
>>> assert index == 91234
>>> index = bisect_left(data, 91234.56)
>>> assert index == 91235

通过 warnings 提醒开发者 API 已经发生变化

假设我们要将这个简单的距离计算函数改成一个更加完善的函数:

def print_distance(speed, duration):
    distance = speed * duration
    print(f'{distance} miles')

修改成:

CONVERSIONS = {
    'mph': 1.60934 / 3600 * 1000, # m/s
    'hours': 3600, # seconds
    'miles': 1.60934 * 1000, # m
    'meters': 1, # m
    'm/s': 1, # m/s
    'seconds': 1, #s
}

def convert(value, units):
    rate = CONVERSIONS[units]
    return rate * value

def localize(value, units):
    rate = CONVERSIONS[units]
    return value / rate

def print_distance(speed, duration, *, speed_units='mph', time_units='hours', distance_units='miles'):
    norm_speed = convert(speed, speed_units)
    norm_duration = convert(duration, time_units)
    norm_distance = norm_speed * norm_duration
    distance = norm_speed * norm_duration
    distance = localize(norm_distance, distance_units)
    print(f"{distance} {distance_units}")


if __name__ == "__main__":
    print_distance(1000, 3, speed_units='meters', time_units='seconds')

我们一方面想让采用旧写法的那些代码继续运转,另一方面又想鼓励开发者及早改用新写法来指明运算时所用的单位。

可以通过 Python 内置的 warnings 模块解决。warnings 模块发出的是警告,而不是那种带有 Error 字样的异常(exception)​,异常主要针对计算机而言,目标是让程序能够自动处理相关的错误​,而警告则是写给开发者的,目标是与他们沟通,告诉对方应该如何正确地使用这个 API。

import warnings

def require(name, value, default):
    if value is not None:
        return value
    warnings.warn(f'{name} will be required soon, update your code', DeprecationWarning, stacklevel=3)
    return default

def print_distance(speed, duration, *, speed_units=None, time_units=None, distance_units=None):
    speed_units = require('speed_units', speed_units, 'mph')
    time_units = require('time_units', time_units, 'hours')
    distance_units = require('distance_units', distance_units, 'miles')

    norm_speed = convert(speed, speed_units)
    norm_duration = convert(duration, time_units)
    norm_distance = norm_speed * norm_duration
    distance = norm_speed * norm_duration
    distance = localize(norm_distance, distance_units)
    print(f"{distance} {distance_units}")


if __name__ == "__main__":
    print_distance(1000, 3, speed_units='meters', time_units='seconds')

这样就可以看到警告:

d:\python_tips\main.py:40: DeprecationWarning: distance_units will be required soon, update your code
  print_distance(1000, 3, speed_units='meters', time_units='seconds')
1.8641182099494205 miles
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自小谷的随笔

上一篇
槽边往事:准备迎接后疫情时代
下一篇
Python 日常必备技巧备忘录(上)