天天看點

python3實作softmax + 函數曲線繪制

繪制softmax函數曲線 + python3實作

import numpy as np

# 實作方法1
def softmax(x):
    return np.exp(x)/np.sum(np.exp(x), axis=0)

# 是想方法2
def softmax2(x):
    """Compute softmax values for each sets of scores in x."""
    x = np.array(x)
    x = np.exp(x)

    if x.ndim == 1:
        sumcol = sum(x)
        for i in range(x.size):
            x[i] = x[i]/float(sumcol)
    if x.ndim > 1:
        sumcol = x.sum(axis = 0)
        for row in x:
            for i in range(row.size):
                row[i] = row[i]/float(sumcol[i])
    return x


data1 = [1]
data2 = [0]
data3 = [1,2,3,4]
data4 = [[1,5,7,3],[2,3,4,5]]
data5 = [[[1,2],[5,3]],[[9,18],[4,22]]]

print(softmax(data1) ,"\n")
print(softmax(data2) ,"\n")
print(softmax(data3) ,"\n")
print(softmax(data4) ,"\n")
print(softmax(data5) ,"\n")

print(softmax2(data1) ,"\n")
print(softmax2(data2) ,"\n")
print(softmax2(data3) ,"\n")
print(softmax2(data4) ,"\n")
print(softmax2(data5) ,"\n")
           

輸出

[ 1.] 

[ 1.] 

[ 0.0320586   0.08714432  0.23688282  0.64391426] 

[[ 0.26894142  0.88079708  0.95257413  0.11920292]
 [ 0.73105858  0.11920292  0.04742587  0.88079708]] 

[[[  3.35350130e-04   1.12535162e-07]
  [  7.31058579e-01   5.60279641e-09]]

 [[  9.99664650e-01   9.99999887e-01]
  [  2.68941421e-01   9.99999994e-01]]] 

[ 1.] 

[ 1.] 

[ 0.0320586   0.08714432  0.23688282  0.64391426] 

[[ 0.26894142  0.88079708  0.95257413  0.11920292]
 [ 0.73105858  0.11920292  0.04742587  0.88079708]] 

Traceback (most recent call last):
  File "/Users/hupeng/Documents/mlNotes/softmax.py", line 41, in <module>
    print(softmax2(data5) ,"\n")
  File "/Users/hupeng/Documents/mlNotes/softmax.py", line 21, in softmax2
    row[i] = row[i]/float(sumcol[i])
TypeError: only length-1 arrays can be converted to Python scalars
           

最後一行報錯了, 因為方法數組越界了

softmax 是将一些列資料歸一化到0-1之間, 是以運用場景中必定會是一維資料或是變向的一維資料

python3實作softmax + 函數曲線繪制
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-10, 10, 200)
y = softmax(x)
print(x,y)
plt.plot(x,y)
plt.show()
           

繪制一下曲線,結果如圖

python3實作softmax + 函數曲線繪制