首頁  >  文章  >  後端開發  >  如何使用Python繪製常見的激活函數曲線?

如何使用Python繪製常見的激活函數曲線?

PHPz
PHPz轉載
2023-04-26 12:01:071735瀏覽

準備工作:下載numpy、matplotlib、sympy

pip install numpy matplotlib sympy

查找對應庫的文檔:

numpy文檔matplotlib文檔sympy文檔

寫程式碼的時候發現vscode不會格式化我的python?查了一下原來還要安裝flake8和yapf,一個是檢查程式碼規格工具一個是格式化工具,接著進行設定setting.json

"python.linting.flake8Enabled": true, // 规范检查工具
"python.formatting.provider": "yapf", // 格式化工具
"python.linting.flake8Args": ["--max-line-length=248"], // 设置单行最长字符限制
"python.linting.pylintEnabled": false, // 关闭pylint工具

準備工作完成, 接下來就看看怎麼寫程式碼

第一步新建一個py檔案

先把激活函數的函數表達式寫出來,這有兩種方式,如果只是單純的得出計算結果,其實用numpy就夠了,但還要自己去求導,那就要用sympy寫出函數式了。

sympy表達函數的方式是這樣的:

from sympy import symbols, evalf, diff
# 我们先要定义自变量是什么,这边按需求来,这是文档的例子有两个变量
x, y = symbols('x y')
# 然后我们写出函数表达式
expr = x + 2*y
# 输出看一下是什么东西
expr # x + 2*y
# 接着就要用我们定义的函数了
expr.evalf(subs={x: 10, y: 20}) # 50.000000
# 再对我们的函数求导
diff(expr, x, 1) # 对x进行求导得出结果 1,这也是表达式

diff為sympy的求導函數

sympy.core.function.diff(f, *symbols , **kwargs)

接著我們定義激活函數的表達式

def sigmoid():
    """
    定义sigmoid函数
    """
    x = symbols('x')
    return 1. / (1 + exp(-x))
def tanh():
    """
    定义tanh函数
    """
    x = symbols('x')
    return (exp(x) - exp(-x)) / (exp(x) + exp(-x))
def relu():
    """
    定义ReLU函数
    """
    x = symbols('x')
    return Piecewise((0, x < 0), (x, x >= 0))
def leakyRelu():
    """
    定义Leaky ReLu函数
    """
    x = symbols(&#39;x&#39;)
    return Piecewise((0.1 * x, x < 0), (x, x >= 0))
def softMax(x: np.ndarray):
    """
    定义SoftMax函数\n
    """
    exp_x = np.exp(x)
    print(exp_x, np.sum(exp_x))
    return exp_x / np.sum(exp_x)
def softmax_derivative(x):
    """
    定义SoftMax导数函数\n
    x - 输入x向量
    """
    s = softMax(x)
    return s * (1 - s)

然後再定義一個求導函數

def derivate(formula, len, variate):
    """
    定义函数求导
      formula:函数公式
      len:求导次数
      variate:自变量
    """
    return diff(formula, variate, len)

這邊有一個問題,為什麼其他函數都是一個,而softMax函數有兩個,一個是softMax函數定義,一個是其導函數定義?

我們來看看softMax函數的樣子

如何使用Python繪製常見的激活函數曲線?

softMax函數分母需要寫累加的過程,使用numpy.sum無法透過sympy去求導(有人可以,我不知道為什麼,可能是使用方式不同,知道的可以交流一下)而使用sympy.Sum或者sympy.summation又只能從i到n每次以1為單位累加

例如:假定有個表達式為m**x (m的x次方)sympy.Sum(m**x, (x, 0, 100))則結果為m**100 m**99 m**98 … m**1,而我定義的ndarray又是np.arange(-10, 10, 0.05),這就無法達到要求,就無法進行求導。

所以就寫兩個函數,一個是原函數定義,一個是導函數定義,之前也說了,如果是求值的話,其實只用numpy就可以完成。

至此,所有函數以及導函數就被我們定義好了

第二步使用matplotlib繪製曲線

首先,我們得知道matplotlib有什麼吧

matplotlib主要有Figure、Axes、Axis、Artist。我理解為figure就是畫布,我們在繪製圖形之前得準備好畫布;axes和axis翻譯都是軸的意思,但是axes應該是坐標軸,axis是坐標軸中的某一個軸;artist為其他可加入的元素

如果要繪製一張簡單的圖可以這樣做

x = np.linspace(0, 2, 100)  # Sample data.

# Note that even in the OO-style, we use `.pyplot.figure` to create the Figure.
fig, ax = plt.subplots(figsize=(5, 2.7), layout=&#39;constrained&#39;)
ax.plot(x, x, label=&#39;linear&#39;)  # Plot some data on the axes.
ax.plot(x, x**2, label=&#39;quadratic&#39;)  # Plot more data on the axes...
ax.plot(x, x**3, label=&#39;cubic&#39;)  # ... and some more.
ax.set_xlabel(&#39;x label&#39;)  # Add an x-label to the axes.
ax.set_ylabel(&#39;y label&#39;)  # Add a y-label to the axes.
ax.set_title("Simple Plot")  # Add a title to the axes.
ax.legend()  # Add a legend.

然後我們準備繪製我們的函數曲線了

plt.xlabel(&#39;x label&#39;) // 两种方式加label,一种为ax.set_xlabel(面向对象),一种就是这种(面向函数)
plt.ylabel(&#39;y label&#39;)

加完laben之後,我考慮了兩種繪製方式,一是把所有曲線都繪製在一個figure裡面,但是分為不同的axes

使用subplot函數可以把figure分成2行2列的axes

plt.subplot(2, 2, 1, adjustable=&#39;box&#39;) # 1行1列
plt.subplot(2, 2, 2, adjustable=&#39;box&#39;) # 1行2列

第二個是透過輸入函數名稱繪製指定的函數

do = input( &#39;input function expression what you want draw(sigmoid, tanh, relu, leakyRelu, softMax)\n&#39; )

得到輸入之後

 try:
        plt.xlabel(&#39;x label&#39;)
        plt.ylabel(&#39;y label&#39;)
        plt.title(do)
        if (do == &#39;softMax&#39;):
            plt.plot(num, softMax(num), label=&#39;Softmax&#39;)
            plt.plot(num, softmax_derivative(num), label=&#39;Softmax Derivative&#39;)
        else:
            plt.plot(
                num,
                [eval(f&#39;{do}()&#39;).evalf(subs={symbols("x"): i}) for i in num])
            plt.plot(num, [
                derivate(eval(f&#39;{do}()&#39;), 1, &#39;x&#39;).evalf(subs={symbols(&#39;x&#39;): i})
                for i in num
            ])

        plt.tight_layout()
        plt.show()
    except TypeError:
        print(
            &#39;input function expression is wrong or the funciton is not configured&#39;
        )

這就完活了,附一張賣家秀

如何使用Python繪製常見的激活函數曲線?

以上是如何使用Python繪製常見的激活函數曲線?的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:yisu.com。如有侵權,請聯絡admin@php.cn刪除