逆变换采样是一种生成遵循任意分布的随机值的方法。由于某种原因,这种方法从未在任何流行的科学图书馆中实施。而且由于经常需要使用它,所以我决定实现为我自己执行的功能,而不是每次都手动计算它。

我要评论的内容:


我是在重新发明轮子吗?我进行了彻底搜索,但找不到类似的内容。
这是解决此问题的正确方法吗?我在SO:链接上看到了以下代码。在那里,PDF通过离散分布来近似。也许,这样更好。我不知道。
sympy中存在很多问题,因此我的功能看起来像一堆补丁和变通办法,以使其正常运行。也许有更优雅,更正确的方法来解决这些缺陷。
缺少输入内容。 PDF的数量是无限的。我可能会想念的。
类型提示。我写的正确吗?使用sympy时,对象的类型非常混乱。
代码样式。


代码:inverse_transform.py

import operator
from typing import Iterator

import numpy as np
import sympy as sym
from scipy.special import lambertw
from sympy.functions.elementary.piecewise import ExprCondPair


def sample(pdf: sym.Function,
           *,
           size: int) -> np.array:
    """
    Generates random values following the given distribution
    :param pdf: input Probability Density Function (PDF)
    :param size: number of generated values
    """
    if not isinstance(pdf, sym.Piecewise):
        raise ValueError("PDF must be constructed by sympy.Piecewise")

    pdf_functions = map(operator.attrgetter('func'),
                        pdf.atoms(sym.Function))
    if sym.re in pdf_functions:
        error_message = ("Using sympy.Abs or sympy.re is not supported "
                         "due to not implemented computing of their integrals "
                         "within SymPy. Split the relevant expression.")
        raise NotImplementedError(error_message)

    # The following is used in order to prevent an error
    # when using PDF in a form of, for example, x**-2.5.
    # For more details see:
    # https://stackoverflow.com/questions/50543587/integrating-piecewise-with-irrational-exponent-gives-error
    pdf = sym.nsimplify(pdf)

    x = pdf.free_symbols.pop()
    y = sym.Dummy('y')

    cdf = sym.integrate(pdf, (x, -sym.oo, y))
    # The following is used in order to prevent
    # long erroneous polynomials
    # when calculating PDF in a form of, for example,  x**-2.5
    # Beware that this will add too much precision. Bug.
    # Issue submitted: https://github.com/sympy/sympy/issues/14787
    cdf = cdf.evalf()

    eq = sym.Eq(x, cdf)

    # TODO: Use solveset when it will be able to deal with LambertW
    # With default rational == True, there will be an error
    # as 'solve' doesn't play along with Piecewise.
    # Related issue: https://github.com/sympy/sympy/issues/12024
    inverse_solutions = sym.solve(eq, y, rational=False)
    # Sometimes, especially for exponents,
    # there are garbage solutions with imaginary parts:
    # https://github.com/sympy/sympy/issues/9973
    inverse_solutions = filter(is_real, inverse_solutions)

    # As, for some reason, 'solve' returns a list of Piecewise's,
    # it's necessary to collect them back together.
    # Related issue: https://github.com/sympy/sympy/issues/14733
    inverse_cdf = recreate_piecewise(inverse_solutions)
    # If inverse CDF will contain LambertW function,
    # we must change its branch. For more details, see:
    # https://stackoverflow.com/questions/49817984/sympy-solve-doesnt-give-one-of-the-solutions-with-lambertw
    functions = map(operator.attrgetter('func'),
                    inverse_cdf.atoms(sym.Function))
    if sym.LambertW in functions:
        inverse_cdf = replace_lambertw_branch(inverse_cdf)
        # This is to prevent LambertW giving ComplexWarning after lambdifying
        inverse_cdf = sym.re(inverse_cdf)

    max_value = cdf.args[-1][0]

    # Warnings can happen with exponents in PDF:
    # https://github.com/sympy/sympy/issues/14789
    lambda_function = sym.lambdify(args=x,
                                   expr=inverse_cdf,
                                   modules=[{'LambertW': lambertw}, 'numpy'])
    return lambda_function(np.random.uniform(high=max_value,
                                             size=size))


def is_real(expression: sym.Expr) -> bool:
    """Checks if expression doesn't contain imaginary part with sympy.I"""
    return sym.I not in expression.atoms()


def recreate_piecewise(functions: Iterator[ExprCondPair]) -> sym.Piecewise:
    """
    Collects Piecewise from list of unsorted Piecewise's,
    ignoring parts with NaNs.
    Solution for the issue: https://github.com/sympy/sympy/issues/14733
    See also question on SO:
    https://stackoverflow.com/questions/50428912/how-to-get-sorted-exprcondpairs-in-a-piecewise-function-that-was-obtained-from
    """
    def remove_nans(expression_condition: ExprCondPair) -> ExprCondPair:
        return expression_condition.args[0]

    def right_hand_number(solution: ExprCondPair) -> sym.S:
        return solution[1].args[1]

    solutions = sorted(map(remove_nans, functions),
                       key=right_hand_number)
    return sym.Piecewise(*solutions)


def to_lower_lambertw_branch(*args) -> sym.Function:
    """
    Wraps the first argument from a given list of arguments
    as a lower branch of LambertW function.
    :return: lower LambertW branch
    """
    return sym.LambertW(args[0], -1)


def replace_lambertw_branch(expression: sym.Expr) -> sym.Expr:
    """
    Replaces upper branch of LambertW function with the lower one.
    For details of the bug see:
    https://stackoverflow.com/questions/49817984/sympy-solve-doesnt-give-one-of-the-solutions-with-lambertw
    Solution is based on the 2nd example from:
    http://docs.sympy.org/latest/modules/core.html?highlight=replace#sympy.core.basic.Basic.replace
    :return: expression with replaced LambertW branch by a lower one
    """
    return expression.replace(sym.LambertW,
                              to_lower_lambertw_branch)



用法示例:
我将绘制结果以便给出更好的主意:

import matplotlib.pyplot as plt
import sympy as sym

import inverse_transform

x = sym.Symbol('x')
f = sym.Piecewise((0, x < 0.),
                  (1, x <= 1.),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




f = sym.Piecewise((0, x < 4.3),
                  (1, x < 12.9),
                  (5, x <= 13.5),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




shift = 1.5
f = sym.Piecewise((0., x <= shift),
                  ((x - shift) * sym.exp(-(x - shift)), x <= 13.5),
                  (0., True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




f = sym.Piecewise((0, x < 6.5),
                  (97.25 / (25 + x**2) , x < 10.5),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




f = sym.Piecewise((0, x < 0.4),
                  (x ** -2.35, x < 50),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




f = sym.Piecewise((0, x < 6.5),
                  (sym.exp(-x/3.5) , x < 10.5),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




f = sym.Piecewise((0, x < -2),
                  (sym.exp(x/0.25) , x < 0),
                  (sym.exp(-x/0.25) , x < 2),
                  (0, True))
plt.hist(inverse_transform.sample(f, size=10**6),
         bins=100)
plt.show()




评论

这是一个不错的代码,您是否进行了性能测试并将其与此处的答案进行比较stackoverflow.com/q/21100716/1391441?另外,如果sympy导致了很多问题(仍然是吗?),是否可以用其他软件包或香草Python代替它?

谢谢!我没有执行任何性能测试,因为我对代码的运行速度非常满意。关于SymPy,从代码中所有已描述的错误中,我只知道一个已修复的错误,对其他错误一无所知,但看起来没有任何变化。由于我的方法使用符号数学,因此恐怕没有其他类似的Python库。也许某些其他语言可以更好地完成此任务,例如MATLAB,Mathematica或R,我们可以从Python中调用它们吗?我认为值得调查。

在重塑轮子部分上,如果我理解正确,可以使用scipy的scipy.stats.rv_continuous来设置连续概率密度函数。然后,您可以使用random_state从该密度函数获取样本。 scipy.stats中也有此类的离散变体。

@agtoever这看起来很有趣!当我有更多空闲时间时,我将进行更多调查。谢谢分享!

#1 楼


如果不需要自定义分段概率(这对我来说很奇怪),而想使用许多SciPy分布之一,则可以使用它们的预定义百分比函数(逆CDF)。仅当您确实需要创建新的发行版时,您才能发现我的下一个观察结果有用。另外,正如agtoever所提到的,您可能想利用rv_continuous中的scipy
使用SymPy是一把两刃剑。象征性地定义和计算CDF和逆样本可能是准确的,但在语法或CPU时间方面可能会花费更多。尽管如此,CDF的数值计算也很昂贵,因此您的方法可能适用于这种分段函数。
对于sample的每次调用,都不应重新计算CDF。您应该将CDF的计算从PDF拆分为其他函数(并具有def sample(cdf: [...])。通常,我的建议与Bob叔叔的建议相同,即,使功能尽可能短。 (例如,您还可以使用一种称为if not isinstance(pdf, sym.Piecewise):的方法从raise NotImplementedError(error_message)validate提取行。
提取CDF和/或验证方法后,它们就可以测试了,您可能想为其添加测试。也许您还可以提取并测试逆的点解。可悲的是,测试随机抽样本身可能是不可能的(将随机失败)。
我不是很熟练的优化,但是在寻找另一种解决方案时,在我看来,从给定的PDF导出CDF,然后求解逐点逆数,您有正确的主意。另一种解决方案是对CDF进行数值查找,并出于性能原因可以选择使用Chebyshev多项式来逼近CDF。 />
虽然我找不到该解决方案的更多重大问题,但我也不确定它是否正确(您永远不会使用软件),但是测试可能会增加人们对代码的信心。 br />

评论


\ $ \ begingroup \ $
感谢您的贡献!很多有用的信息在这里。我不能保证会在不久的将来仔细查看所提供的某些链接,但是我敢肯定,此答案将对将来的访问者有所帮助。
\ $ \ endgroup \ $
–乔治
20年8月28日在15:04

\ $ \ begingroup \ $
感谢您的反馈:)老实说,我期望这样的回答(9年以上)不会对您个人有所帮助,但对其他人有所帮助。出于好奇,您是否需要分段概率?
\ $ \ endgroup \ $
– Danuker
20年8月28日在17:45

\ $ \ begingroup \ $
实际上只有2岁以上,但是是的,最近我一直在研究完全不同的主题。我当时的研究工作与银矮星的银河种群有关,这些分段的概率来自它们的年龄和速度的某些分布。不幸的是,我记不清他们背后的原因是分段的,以及是否可以避免。
\ $ \ endgroup \ $
–乔治
20年8月28日在20:43

#2 楼

考虑到这个问题已经有将近2年没有答案了,我想我应该把2c放进去,然后添加一些东西。

我是在重新发明轮子吗?我进行了彻底搜索,但找不到类似的内容
这是解决此问题的正确方法吗?

我认为在堆栈溢出/数学上比对代码审阅会更好地提出这些问题。验证证明是他们的工作重点,我们更多地是在解决您的代码问题(这可能是为什么这个问题尚未得到答案的观点已经超过3K,并且使用了将近2年的时间?)。
看看您的代码,看起来不错,但是我要指出一些问题(在sample中)-
   if not isinstance(pdf, sym.Piecewise):
        raise ValueError("PDF must be constructed by sympy.Piecewise")

如果出现此错误-sympy是否可以捕获并处理?还是死了?如果无法正确捕获,则只打印消息并返回None将是更好的方法。这是更标准的IMO模式。
再次使用raise NotImplementedError(error_message)-如果您在代码中到达无法逾越的位置-除非在上面处理该消息,否则没有任何意义。只需记录/打印错误(没有继续的地方,对吗?)并结束程序。
关于replace_lambertw_branchto_lower_lambertw_branch函数-它们仅使用一次。我了解需要引入清晰度,但是当只使用一次(并且是一行)时,不要用所有的脚手架创建一个完整的功能。这是不必要的代码,迫使您的代码读者四处寻找试图解决的问题。
对于remove_nansright_hand_number的内部函数,我也可以说同样的话,尽管它们之间的距离更近了,但仍然需要代码阅读器停止流,然后转到其他地方查找正在发生的事情。
它们也只使用一次-当lambda起作用时,无需创建整个支架。
添加更多注释-如果代码的用户/读者不清楚lambda的功能,则可以改进lambda中变量的命名以使其更清楚,或者在lambda上方添加注释以解释“为什么”-但是
继续说明注释-它们分散在整个代码中。这是不好的,有两个原因。一个-它打断了对代码流的理解-并撒谎了。如果它们都很重要-应该将它们移到一个单独的文档中,解释您的方法以及执行方法的原因。这些“为什么”注释不不属于代码-因为代码由于某种原因而改变的那一刻-代码随即出现注释-会使查看您代码的所有人感到困惑。
很少有善意的程序员尝试修复代码以匹配注释(很少见,但我已经看到了)。
否则,其余代码似乎还不错。也许其他人可以补充我写的内容。我希望这会有所帮助。

评论


\ $ \ begingroup \ $
感谢您的评论!自从我使用这段代码已经很长时间了,所以我不记得所有的细节了。我不记得是为什么我提出了ValueError而不是让SymPy提出它的确切原因,但是,我想,这是因为SymPy的错误消息有些胡言乱语,无法帮助用户解决问题。我不同意返回None,打印一条消息而不是引发错误,并且有选择地立即结束程序。最终这是一个设计选择,但是我个人会坚持提出例外,因为...
\ $ \ endgroup \ $
–乔治
20年5月11日在12:20

\ $ \ begingroup \ $
... 1)与打印的消息不同,异常消息进入stderr,2)对于我的函数,无异常是没有意义的,除非出现问题,否则用户始终会排除一个值数组; 3)结束程序应该在最高级别完成,我可以想象我的函数被嵌入到某些代码中,该代码只是按照给定分布生成值的几种选择之一;如果我的方法失败,则高级程序可能会回退到例如某些Monte Carlo值生成器,调用sys.exit()会使事情复杂化...
\ $ \ endgroup \ $
–乔治
20年5月11日在12:21

\ $ \ begingroup \ $
...尽管,现在,我再次看我的代码,我发现如果我编写一些辅助函数以使用sympy.Abs扩展表达式,则可以避免NotImplementedError。我同意谈到助手功能,其中有些是不必要的。具有is_real的表达式可以用生成器表达式重写,并且replace_lambertw_branch可以完全删除。
\ $ \ endgroup \ $
–乔治
20年5月11日在12:21

\ $ \ begingroup \ $
...我还记得考虑注释的问题,通常我在代码中都反对它们,但是在这种情况下,我认为它们在代码中的存在是有道理的,因为根本不清楚为什么相应的行代码是必需的。我也可以按照您的建议将注释移到其他位置,但是对我来说,直接在代码中编辑注释会更容易。
\ $ \ endgroup \ $
–乔治
20年5月11日在12:23