天天看点

Python 数据分析——matplotlib 绘图函数简介

作者:昌华量化
Python 数据分析——matplotlib 绘图函数简介

本节介绍如何使用matplotlib绘制一些常用的图表。matplotlib的每个绘图函数都有许多关键字参数用来设置图表的各种属性,一般来说,如果读者需要对图表进行某种特殊的设置,可以在绘图函数的说明文档或者matploblib的演示页面中找到相关的说明。

一、对数坐标图

前面介绍过如何使用plot()绘制曲线图,所绘制图表的X-Y轴坐标都是算术坐标。下面我们看看如何在对数坐标系中绘图。

绘制对数坐标图的函数有三个:semilogx()、semilogy()、loglog()。它们分别绘制X轴为对数坐标、Y轴为对数坐标以及两个轴都为对数坐标的图表。

下面的程序使用4种不同的坐标系绘制低通滤波器的频率响应曲线,结果如图1所示。其中,左上图为plot()绘制的算术坐标系,右上图为semilogx()绘制的X轴对数坐标系,左下图为semilogy()绘制的Y轴对数坐标系,右下图为loglog()绘制的双对数坐标系。使用双对数坐标系表示的频率响应曲线通常被称为波特图。

Python 数据分析——matplotlib 绘图函数简介

图1 低通滤波器的频率响应:算术坐标(左上)、X轴对数坐标(右上)、Y轴对数坐标(左下)、双对数坐标(右上)

w = np.linspace(0.1, 1000, 1000)
p = np.abs(1/(1+0.1j*w)) # 计算低通滤波器的频率响应

fig, axes = plt.subplots(2, 2)

functions = ("plot", "semilogx", "semilogy", "loglog")

for ax, fname in zip(axes.ravel(), functions):
func = getattr(ax, fname)
func(w, p, linewidth=2)
ax.set_ylim(0, 1.5)           

二、极坐标图

极坐标系是和笛卡尔坐标系完全不同的坐标系,极坐标系中的点由一个夹角和一段相对中心点的距离表示。下面的程序绘制极坐标图,效果如图2所示。

Python 数据分析——matplotlib 绘图函数简介

图2 极坐标中的圆、螺旋线和玫瑰线

theta = np.arange(0, 2*np.pi, 0.02)

plt.subplot(121, polar=True) ❶
plt.plot(theta, 1.6*np.ones_like(theta), linewidth=2) ❷
plt.plot(3*theta, theta/3, "--", linewidth=2)

plt.subplot(122, polar=True)
plt.plot(theta, 1.4*np.cos(5*theta), "--", linewidth=2)
plt.plot(theta, 1.8*np.cos(4*theta), linewidth=2)
plt.rgrids(np.arange(0.5, 2, 0.5), angle=45) ❸
plt.thetagrids([0, 45]) ❹           

❶调用subplot()创建子图时通过设置polar参数为True,创建一个极坐标子图。❷然后调用plot()在极坐标子图中绘图。也可以使用polar()直接创建极坐标子图并在其中绘制曲线。

❸rgrids()设置同心圆栅格的半径大小和文字标注的角度。因此右图中的虚线圆圈有三个,半径分别为0.5、1.0和1.5,这些文字沿着45°线排列。❹thetagrids()设置放射线栅格的角度,因此右图中只有两条放射线栅格线,角度分别为0°和45°。

三、柱状图

柱状图用其每根柱子的长度表示值的大小,它们通常用来比较两组或多组值。下面的程序从文件中读入中国人口的年龄的分布数据(人口分布数据由维基百科提供,仅供参考,不保证正确性),并使用柱状图比较男性和女性的年龄分布,效果如图3所示。

Python 数据分析——matplotlib 绘图函数简介

图3 中国男女人口的年龄分布图

data = np.loadtxt("china_population.txt")
width = (data[1,0] - data[0,0])*0.4 ❶
plt.figure(figsize=(8, 4))
c1, c2 = plt.rcParams['axes.color_cycle'][:2]
plt.bar(data[:,0]-width, data[:,1]/1e7, width, color=c1, label=u"男") ❷
plt.bar(data[:,0], data[:,2]/1e7, width, color=c2, label=u"女") ❸
plt.xlim(-width, 100)
plt.xlabel(u"年龄")
plt.ylabel(u"人口(千万)")
plt.legend()           

读入的数据中,第一列为年龄,它将作为柱状图的横坐标。❶首先计算柱状图中每根柱子的宽度,因为要在每个年龄段上绘制两根柱子,因此柱子的宽度应该小于年龄段的二分之一。这里以年龄段的0.4倍作为柱子的宽度。

❷调用bar()绘制男性人口分布的柱状图。它的第一个参数为每根柱子的左边缘的横坐标,为了让男性和女性的柱子以年龄刻度为中心,这里让每根柱子左侧的横坐标为“年龄减去柱子的宽度”。bar()的第二个参数为每根柱子的高度,第三个参数指定所有柱子的宽度。当第三个参数为序列时,可以为每根柱子指定宽度。

❸绘制女性人口分布的柱状图,这里以年龄为柱子的左边缘横坐标,因此女性和男性的人口分布图以年龄刻度为中心。由于bar()不自动修改颜色,因此程序中通过color参数设置两个柱状图的颜色。

四、散列图

使用plot()绘图时,如果指定样式参数为只绘制数据点,那么所绘制的就是一幅散列图。例如:

plt.plot(np.random.random(100), np.random.random(100), "o")           

但是这种方法所绘制的点无法单独指定颜色和大小。scatter()所绘制的散列图可以指定每个点的颜色和大小。下面的程序演示了scatter()的用法,效果如图4所示。

Python 数据分析——matplotlib 绘图函数简介

图4 可指定点的颜色和大小的散列

plt.figure(figsize=(8, 4))
x = np.random.random(100)
y = np.random.random(100)
plt.scatter(x, y, s=x*1000, c=y, marker=(5, 1),
alpha=0.8, lw=2, facecolors="none")
plt.xlim(0, 1)
plt.ylim(0, 1)           

scatter()的前两个参数是两个数组,分别指定每个点的X轴和Y轴的坐标。s参数指定点的大小,其值和点的面积成正比,可以是单个数值或数组。

c参数指定每个点的颜色,也可以是数值或数组。这里使用一维数组为每个点指定了一个数值。通过颜色映射表,每个数值都会与一个颜色相对应。默认的颜色映射表中蓝色与最小值对应,红色与最大值对应。当c参数是形状为(N, 3)或(N, 4)的二维数组时,则直接表示每个点的RGB颜色。

marker参数设置点的形状,可以是一个表示形状的字符串,或是表示多边形的两个元素的元组,第一个元素表示多边形的边数,第二个元素表示多边形的样式,取值范围为0、1、2、3。0表示多边形,1表示星形,2表示放射形,3表示忽略边数显示为圆形。

最后,通过alpha参数设置点的透明度,lw参数设置线宽,它是linewidth的缩写。facecolors参数为"none"表示散列点没有填充色。

五、图像

imread()和imshow()提供了简单的图像载入和显示功能。imread()可以从图像文件读入数据,得到一个表示图像的NumPy数组。它的第一个参数是文件名或文件对象,format参数指定图像类型,如果省略则由文件的扩展名决定图像类型。对于灰度图像,它返回一个形状为(M, N)的数组;对于彩色图像,它返回形状为(M, N, C)的数组。其中M为图像的高度,N为图像的宽度,C为3或4,表示图像的通道数。下面的程序从lena.jpg中读入图像数据,效果如图5所示。所得到的数组img是一个形状为(393, 512, 3)的单字节无符号整数数组。这是因为通常所使用的图像采用单字节分别保存每个像素的红、绿、蓝三个通道的分量:

Python 数据分析——matplotlib 绘图函数简介

图5 用imread()和imshow()显示图像

img = plt.imread("lena.jpg")
print img.shape, img.dtype
(393, 512, 3) uint8           

下面使用imshow()显示img所表示的图像:

❶imshow()可以用来显示imread()所返回的数组。如果数组是表示多通道图像的三维数组,则每个像素的颜色由各个通道的值决定。

❷imshow()所绘制图表的Y轴的正方向是从上往下的。如果设置imshow()的origin参数为"lower",则所显示图表的原点在左下角,但是整个图像就上下颠倒了。

❸如果三维数组的元素类型为浮点数,则元素值的取值范围为0.0到1.0,与颜色值0到255对应。超过这个范围可能会出现颜色异常的像素。下面的例子将数组img转换为浮点数组并用imshow()进行显示,由于数值范围超过了0.0~1.0,因此颜色显示异常。

❹而取值在0.0~1.0的浮点数组和原始图像完全相同。

❺使用clip()将超出范围的值限制在取值范围之内,可以使整个图像变亮。

❻如果imshow()的参数是二维数组,则使用颜色映射表决定每个像素的颜色。这里显示图像中的红色通道,它是一个二维数组。其显示效果比较吓人,因为默认的图像映射将最小值映射为蓝色、将最大值映射为红色。可以使用colorbar()将颜色映射表在图表中显示出来。

❼通过imshow()的cmap参数可以修改显示图像时所采用的颜色映射表,使用名为copper的颜色映射表显示图像的红色通道。

img = plt.imread("lena.jpg")
fig, axes = plt.subplots(2, 4, figsize=(11, 4))
fig.subplots_adjust(0, 0, 1, 1, 0.05, 0.05)

axes = axes.ravel()

axes[0].imshow(img)                        ❶
axes[1].imshow(img, origin="lower")        ❷
axes[2].imshow(img * 1.0)                  ❸
axes[3].imshow(img / 255.0)                ❹
axes[4].imshow(np.clip(img / 200.0, 0, 1)) ❺

axe_img = axes[5].imshow(img[:, :, 0])     ❻
plt.colorbar(axe_img, ax=axes[5])

axe_img = axes[6].imshow(img[:, :, 0], cmap="copper") ❼
plt.colorbar(axe_img, ax=axes[6])

for ax in axes:
ax.set_axis_off()           

颜色映射表是一个ColorMap对象,matplotlib中已经预先定义了很多颜色映射表,可以通过下面的语句找到这些颜色映射表的名字:

import matplotlib.cm as cm
cm._cmapnames[:5]
['Spectral', 'copper', 'RdYlGn', 'Set2', 'summer']           

使用imshow()可以显示任意的二维数据,例如下面的程序使用图像直观地显示了二元函数,效果如图6所示。

Python 数据分析——matplotlib 绘图函数简介

图6 使用imshow()可视化二元函数

y, x = np.ogrid[-2:2:200j, -2:2:200j]
z = x * np.exp( - x**2 - y**2) ❶

extent = [np.min(x), np.max(x), np.min(y), np.max(y)] ❷

plt.figure(figsize=(10,3))
plt.subplot(121)
plt.imshow(z, extent=extent, origin="lower") ❸
plt.colorbar()
plt.subplot(122)
plt.imshow(z, extent=extent, cmap=cm.gray, origin="lower")
plt.colorbar()           

❶首先通过数组的广播功能计算出表示函数值的二维数组z,注意它的第0轴表示Y轴、第1轴表示X轴。❷然后将X、Y轴的取值范围保存到extent列表中。❸将extent列表传递给imshow()的extent参数,这样图表的X、Y轴的刻度标签将使用extent列表指定的范围。

六、等值线图

还可以使用等值线图表示二元函数。所谓等值线,是指由函数值相等的各点连成的平滑曲线。等值线可以直观地表示二元函数值的变化趋势,例如等值线密集的地方表示函数值在此处的变化较大。matplotlib中可以使用contour()和contourf()描绘等值线,它们的区别是contourf()所得到的是带填充效果的等值线。下面的程序演示了这两个函数的用法,效果如图7所示:

Python 数据分析——matplotlib 绘图函数简介

图7 用contour(左)和contourf(右)描绘等值线图

y, x = np.ogrid[-2:2:200j, -3:3:300j] ❶
z = x * np.exp( - x**2 - y**2)

extent = [np.min(x), np.max(x), np.min(y), np.max(y)]

plt.figure(figsize=(10,4))
plt.subplot(121)
cs = plt.contour(z, 10, extent=extent) ❷
plt.clabel(cs) ❸
plt.subplot(122)
plt.contourf(x.reshape(-1), y.reshape(-1), z, 20) ❹           

❶为了更清楚地区分X轴和Y轴,这里让它们的取值范围和等分次数均不相同。这样所得到的数组z的形状为(200, 300),它的第0轴对应Y轴,第1轴对应X轴。

❷调用contour()绘制数组z的等值线图,第二个参数为10表示将整个函数的取值范围等分为10个区间,即其所显示的等值线图中将有9条等值线。和imshow()一样,可以使用extent参数指定等值线图的X轴和Y轴的数据范围。❸contour()所返回的是一个QuadContourSet对象,将它传递给clabel(),为其中的等值线标上对应的值。

❹调用contourf()绘制带填充效果的等值线图。这里演示了另一种设置X、Y轴取值范围的方法。它的前两个参数分别是计算数组z时所使用的X轴和Y轴上的取样点,这两个数组必须是一维数组或是形状与数组z相同的数组。

如果需要对散列点数据绘制等值线图,可以先使用scipy.interpolate模块中提供的插值函数将散列点数据插值为网格数据。

还可以使用等值线绘制隐函数曲线。所谓隐函数,是指在一个方程中,若令x在某一区间内取任意值时总有相应的y满足此方程,则可以说方程在该区间上确定了x的隐函数y,如隐函数x2+y2-1=0表示一个单位圆。

显然无法像绘制一般函数那样,先创建一个等差数组表示变量x的取值点,然后计算出数组中每个x所对应的y值。可以使用等值线解决这个问题,显然隐函数的曲线就是值等于0的那条等值线。下面的程序绘制函数:

f(x,y)=(x2+y2 )4-(x2-y2 )2

在f(x,y)=0和f(x,y)-0.1=0时的曲线,效果如图8(左)所示。

Python 数据分析——matplotlib 绘图函数简介

图8 使用等值线绘制隐函数曲线(左),获取等值线数据并绘图(右)

y, x = np.ogrid[-1.5:1.5:200j, -1.5:1.5:200j]
f = (x**2 + y**2)**4 - (x**2 - y**2)**2

plt.figure(figsize=(9, 4))
plt.subplot(121)
extent = [np.min(x), np.max(x), np.min(y), np.max(y)]
cs = plt.contour(f, extent=extent, levels=[0, 0.1],    ❶
colors=["b", "r"], linestyles=["solid", "dashed"], linewidths=[2, 2])

plt.subplot(122)
for c in cs.collections: ❷
data = c.get_paths()[0].vertices
plt.plot(data[:,0], data[:,1],
color=c.get_color()[0],  linewidth=c.get_linewidth()[0])           

❶在调用contour()绘制等值线时,可以通过levels参数指定等值线所对应的函数值,这里设置levels参数为[0, 0.1],因此最终将绘制两条等值线。通过colors、linestyles、linewidths等参数可以分别指定每条等值线的颜色、线型以及线宽。

仔细观察图8(左)会发现,表示隐函数f(x,y)=0的蓝色实线并不是完全连续的,在图的中间部分它由许多孤立的小段构成。因为等值线在原点附近无限靠近,所以无论对函数f的取值空间如何进行细分,总是会有无法分开的地方,最终造成了图中的那些孤立的细小区域,而表示隐函数f(x,y)-0.1=0的红色虚线则是闭合且连续的。

❷从等值线集合cs中找到表示等值线的路径,并使用plot()将其绘制出来,效果如图8(右)所示。

contour()返回一个QuadContourSet对象,其collections属性是一个等值线列表,每条等值线用一个LineCollection对象表示:

print cs
cs.collections
<matplotlib.contour.QuadContourSet instance at 0x057EFC10>
<a list of 2 mcoll.LineCollection objects>           

每个LineCollection对象都有它自己的颜色、线型、线宽等属性,注意这些属性所获得结果的外面还有一层包装,要获得其第0个元素才是真正的配置:

print cs.collections[0].get_color()[0]
print cs.collections[0].get_linewidth()[0]
[ 0. 0. 1. 1.]
2           

在前面的章节介绍过LineCollection对象是一组曲线的集合,因此它可以表示蓝色实线那样由多条线构成的等值线。它的get_paths()方法获得构成等值线的所有路径,本例中蓝色实线所表示的等值线由42条路径构成:

len(cs.collections[0].get_paths())
42           

路径是一个Path对象,通过它的vertices属性可以获得路径上所有点的坐标:

path = cs.collections[0].get_paths()[0]
path.vertices
array([[-0.08291457, -0.98938936],
[-0.09039269, -0.98743719],
[-0.09798995, -0.98513674],
...,
[-0.05276382, -0.99548781],
[-0.0678392 , -0.99273907],
[-0.08291457, -0.98938936]])           

七、四边形网格

pcolormesh(X, Y, C)绘制由X、Y和C三个数组定义的四边形网格。这三个数组是二维数组,X和Y的形状相同,C的形状可以和X、Y相同,也可以比它们少一行一列。每个四边形的4个顶点的X轴坐标由X中上下左右相邻的4个元素决定,Y轴坐标由Y中对应的4个元素决定。四边形的颜色由C中对应的元素以及颜色映射表决定。

在下面的例子中,X和Y的形状都是(2, 3),其中有两组上下左右相邻的4个元素,定义两个四边形的4个顶点:

第一个四边形的顶点第二个四边形的顶点

================= ================

(0, 0), (1, 0.2) (1, 0.2), (2, 0)

(0, 1), (1, 0.8) (1, 0.8), (2, 1)

每个四边形的填充颜色与Z中的一个元素对应:

X = np.array([[0, 1, 2],
[0, 1, 2]])
Y = np.array([[0, 0.2, 0],
[1, 0.8, 1]])
Z = np.array([[0.5, 0.8]])           

下面将X和Y平坦化之后用plot()绘制出这些顶点的坐标,然后调用pcolormesh()绘制这两个四边形。与左边的四边形对应的颜色映射值为0.5,与右边的四边形对应的颜色映射值为0.8,因此一个显示为蓝色,另一个显示为红色。

plt.plot(X.ravel(), Y.ravel(), "ko")
plt.pcolormesh(X, Y, Z)
plt.margins(0.1)           
Python 数据分析——matplotlib 绘图函数简介

图9 演示pcolormesh()绘制的四边形及其填充颜色

在下面的例子中,使用pcolormesh()绘制复数平面上的坐标变换。在图10中,左侧的图表显示s平面上的矩形区域,右侧的图表显示通过公式

Python 数据分析——matplotlib 绘图函数简介

坐标变换之后的网格,左侧中的矩形被变换成右侧同颜色的四边形。由于axes[2]和axes[3]中的网格由近4万个四边形组成,为了在输出SVG图像时提高绘图速度,这里将rasterized参数设置为True,这些四边形将作为一幅点阵图像输出到SVG图像中。

Python 数据分析——matplotlib 绘图函数简介

图10 使用pcolormesh()绘制复数平面上的坐标变换

def make_mesh(n):
x, y = np.mgrid[-10:0:n*1j, -5:5:n*1j]

s = x + 1j*y
z = (2 + s) / (2 - s)
return s, z

fig, axes = plt.subplots(2, 2, figsize=(8, 8))
axes = axes.ravel()
for ax in axes:
ax.set_aspect("equal")

s1, z1 = make_mesh(10)
s2, z2 = make_mesh(200)
axes[0].pcolormesh(s1.real, s1.imag, np.abs(s1))
axes[1].pcolormesh(z1.real, z1.imag, np.abs(s1))
axes[2].pcolormesh(s2.real, s2.imag, np.abs(s2), rasterized=True)
axes[3].pcolormesh(z2.real, z2.imag, np.abs(s2), rasterized=True)           

还可以在极坐标中使用pcolormesh()绘制网格,下面的例子使用mgrid[]创建极坐标中的等间隔网格,然后在projection为polar的子图中绘制这个网格:

def func(theta, r):
y = theta * np.sin(r)
return np.sqrt(y*y)

T, R = np.mgrid[0:2*np.pi:360j, 0:10:100j]
Z = func(T, R)

ax=plt.subplot(111, projection="polar", aspect=1.)
ax.pcolormesh(T, R, Z, rasterized=True)           
Python 数据分析——matplotlib 绘图函数简介

图11 使用pcolormesh()绘制极坐标中的网格

八、三角网格

在工业工程设计与分析中,经常将分析对象使用三角网格离散化,然后用有限元法进行模拟。在matplotlib中提供了下面的三角网格绘制函数:

·triplot():绘制三角网格的边线。

·tripcolor():与pcolormesh()类似,绘制填充颜色的三角网格。

·tricontour()和tricontourf():绘制三角网格的等高线。

diffusion.txt是使用FiPy对二维稳态热传导问题进行有限元模拟的结果。该文件分为三个部分:

·以#points开头的部分是一个形状为(N_points, 2)的数组,保存N_points个点的坐标。

·以#triangles开头的部分是一个形状为(N_triangles, 3)的数组,保存每个三角形三个顶点在points数组中的下标。

·以#values开头的部分是一个形状为(N_triangles, 1)的数组,保存每个三角形对应的温度。

下面的程序将这些数据读入data字典:

with open("diffusion.txt") as f:
data = {"points":[], "triangles":[], "values":[]}
values = None
for line in f:
line = line.strip()
if not line:
continue
if line.startswith("#"):
values = data[line[1:]]
continue
values.append([float(s) for s in line.split()])

data = {key:np.array(data[key]) for key in data}           

然后就可以调用trip*(),用三角形网格显示目标区域的温度,结果如图12所示。

Python 数据分析——matplotlib 绘图函数简介

图12 使用tripcolor()和tricontour()绘制三角网格和等值线

❶tripcolor()的参数从左到右分别为各点的X轴坐标、Y轴坐标、三角形顶点下标、标量数组。标量数组中的每个值可以与每个顶点对应,也可以与每个三角形对应。在本例中由于values的长度与triangles的第0轴长度相同,因此每个值与三角形相对应。若标量数组的长度与顶点数相同,则每个三角形对应的值由其三个顶点的平均值决定。

❷调用triplot()绘制所有三角形的边线。❸调用tricontour()绘制等高线。由于要求标量数组与三角形顶点相对应,而本例中标量数组与三角形对应,因此先计算每个三角形的重心坐标Xc和Yc,这样values中的每个值就可以与每个三角形的重心对应。在调用tricontour()时没有传递三角形顶点下标信息,这时会调用matplotlib自带的三角化算法计算出每个三角形对应的顶点。

X, Y = data["points"].T
triangles = data["triangles"].astype(int)
values = data["values"].squeeze()

fig, ax = plt.subplots(figsize=(12, 4.5))
ax.set_aspect("equal")

mapper = ax.tripcolor(X, Y, triangles, values, cmap="gray") ❶
plt.colorbar(mapper, label=u"温度")

plt.triplot(X, Y, triangles, lw=0.5, alpha=0.3, color="k") ❷

Xc = X[triangles].mean(axis=1)
Yc = Y[triangles].mean(axis=1)
plt.tricontour(Xc, Yc, values, 10) ❸           

九、箭头图

使用quiver()可以用大量的箭头表示矢量场。下面的程序显示的梯度场,结果如图13所示。vec_field(f, x, y)近似计算函数f在x和y处的偏导数。

Python 数据分析——matplotlib 绘图函数简介

图13 用quiver()绘制矢量场

quiver()的前5个参数中,X、Y是箭头起点的X轴和Y轴坐标,U、V是箭头方向和大小的矢量,C是箭头对应的值。

def f(x, y):
return x * np.exp(- x**2 - y**2)

def vec_field(f, x, y, dx=1e-6, dy=1e-6):
x2 = x + dx
y2 = y + dy
v = f(x, y)
vx = (f(x2, y) - v) / dx
vy = (f(x, y2) - v) / dy
return vx, vy

X, Y = np.mgrid[-2:2:20j, -2:2:20j]
C = f(X, Y)
U, V = vec_field(f, X, Y)
plt.quiver(X, Y, U, V, C)
plt.colorbar();
plt.gca().set_aspect("equal")           

此外,quiver()还提供许多参数来配置箭头的大小和方向:

·箭头的长度由scale和scale_units决定。其中scale为数值,表示箭头的缩放尺度,而scale_units为箭头的长度单位,可选单位有'width'、'height'、'dots'、'inches'、'x'、'y'、'xy'等。其中'width'、'height'为子图的宽和高,'dots'和'inches'以点和英寸为单位,'x'、'y'、'xy'则以数据坐标系的X轴、Y轴或单位矩形的对角线为单位。箭头的长度按照“UV矢量的长度 * 箭头的长度单位 / 缩放尺度”计算。例如,如果scale为2,scale_units为'x',而UV矢量的长度为3,则对应的箭头的长度为1.5个X轴的单位长度。

·width、headwidth、headlength和headaxislength等参数决定箭头的杆部分粗细、箭头部分的大小以及长度,而units参数决定这些参数的单位,可选值与scale_units相同。这些参数的含义如图14所示。

Python 数据分析——matplotlib 绘图函数简介

图14 quiver箭头的各个参数的含义

·pivot参数决定箭头旋转的中心,可以为'tail'、'middle'、'tip'等值,在图14中使用灰色圆点表示这些旋转点。

·angles参数决定箭头的方向。正方形可能由于X轴和Y轴的缩放尺度不同而显示为长方形,因此方向有两种计算方式:'uv'和'xy'。其中'uv'只采用U和V的值计算方向,因此若U和V的值相同,则方向为45度;而'xy'在使用U和V计算角度时考虑X轴和Y轴的缩放尺度。

下面通过两个例子帮助读者理解这些参数的用法,如图15所示。首先绘制了一条参数曲线,然后沿着该曲线绘制了40个等分曲线的箭头,箭头的方向表示箭头处曲线的切线方向,颜色表示箭头所在处参数的大小。计算部分留给读者自行分析,下面仔细分析这些参数是如何决定箭头的大小和方向的。

Python 数据分析——matplotlib 绘图函数简介

图15 使用箭头表示参数曲线的切线方向

箭头的长度和其他尺寸的单位由scale_units和units决定,在本例中均为'dots',即以像素点为单位。dx和dy为描述箭头的矢量,长度为1,将scale参数设置为1.0/arrow_size,这样所有箭头的长度均为arrow_size个像素点。箭杆的宽度由width参数指定,本例中的宽度为1个像素。而headwidth、headlength和headaxislength等参数决定箭头部分的宽度、长度以及箭头与箭杆接触部分的长度,这些参数为对应长度与箭杆宽度的比例系数。在本例中,由于箭杆宽度为1个像素,因此箭头宽度为arrow_size * 0.5个像素,而箭头部分的长度和箭头的长度相同,因此图中的箭头没有箭杆部分。

由于子图的X轴和Y轴的缩放比例不同,因此设置angles参数为"xy",这样箭头的方向才能与曲线的切线方向相同。

n = 40
arrow_size = 16
t = np.linspace(0, 1, 1000)
x = np.sin(3*2*np.pi*t)
y = np.cos(5*2*np.pi*t)
line, = plt.plot(x, y, lw=1)

lengths = np.cumsum(np.hypot(np.diff(x), np.diff(y)))
length = lengths[-1]
arrow_locations = np.linspace(0, length, n, endpoint=False)
index = np.searchsorted(lengths, arrow_locations)
dx = x[index + 1] - x[index]
dy = y[index + 1] - y[index]
ds = np.hypot(dx, dy)
dx /= ds
dy /= ds
plt.quiver(x[index], y[index], dx, dy, t[index],
units="dots", scale_units="dots",
angles="xy", scale=1.0/arrow_size, pivot="middle",
edgecolors="black", linewidths=1,
width=1, headwidth=arrow_size*0.5,
headlength=arrow_size, headaxislength=arrow_size,
zorder=100)
plt.colorbar()
plt.xlim([-1.5, 1.5])
plt.ylim([-1.5, 1.5])           

还可以用quiver()绘制起点和终点的箭头集合。下面的例子绘制神经网络结构示意图,效果如图16所示。为了让箭头能够连接两个神经节点,将scale_units设置为"xy",将angles设置为"xy",并且将scale设置为1。这样箭头的长度就为箭头对应的矢量在数据空间中的长度。

Python 数据分析——matplotlib 绘图函数简介

图16 使用quiver()绘制神经网络结构示意图

levels = [4, 5, 3, 2]
x = np.linspace(0, 1, len(levels))

for i in range(len(levels) - 1):
j = i + 1
n1, n2 = levels[i], levels[j]
y1, y2 = np.mgrid[0:1:n1*1j, 0:1:n2*1j]
x1 = np.full_like(y1, x[i])
x2 = np.full_like(y2, x[j])
plt.quiver(x1, y1, x2-x1, y2-y1,
angles="xy", units="dots", scale_units="xy",
scale=1, width=2, headlength=10,
headaxislength=10, headwidth=4)

yp = np.concatenate([np.linspace(0, 1, n) for n in levels])
xp = np.repeat(x, levels)
plt.plot(xp, yp, "o", ms=12)
plt.gca().axis("off")
plt.margins(0.1, 0.1)           

十、三维绘图

mpl_toolkits.mplot3d模块在matplotlib的基础上提供了三维作图的功能。由于它使用matplotlib的二维绘图功能实现三维图形的绘制工作,因此绘图速度有限,不适合用于大规模数据的三维绘图。

下面是绘制三维曲面的程序,程序的输出如图17所示。

Python 数据分析——matplotlib 绘图函数简介

​图17 使用mplot3D绘制的三维曲面图

import mpl_toolkits.mplot3d ❶

x, y = np.mgrid[-2:2:20j, -2:2:20j] ❷
z = x * np.exp( - x**2 - y**2)

fig = plt.figure(figsize=(8, 6))
ax = plt.subplot(111, projection='3d') ❸
ax.plot_surface(x, y, z, rstride=2, cstride=1, cmap = plt.cm.Blues_r) ❹
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")           

❶首先载入mplot3d模块,matplotlib中三维绘图相关的功能均在此模块中定义。❷使用mgrid创建X-Y平面的网格并计算网格上每点的高度z。由于绘制三维曲面的函数要求其X、Y和Z轴的数据都用相同形状的二维数组表示,因此这里不能使用ogrid创建。和之前的imshow()不同,数组的第0轴可以表示X和Y轴中的任意一个,在本例中第0轴表示X轴,第1轴表示Y轴。

❸在当前图表中创建一个子图,通过projection参数指定子图的投影模式为"3d",这样subplot()将返回一个用于三维绘图的Axes3D子图对象。

投影模式

投影模式决定了点从数据坐标转换为屏幕坐标的方式。可以通过下面的语句获得当前有效的投影模式的名称:

>>> from matplotlib import projections
>>> projections.get_projection_names()
['3d', 'aitoff', 'hammer', 'lambert', 'mollweide', 'polar', 'rectilinear']           

只有在载入mplot3d模块之后此列表中才会出现'3d'投影模式。'aitoff'、'hammer'、'lambert'、'mollweide'等均为地图投影,'polar'为极坐标投影,'rectilinear'则是默认的直线投影模式。

❹调用Axes3D对象的plot_surface()绘制三维曲面图。其中参数x、y、z都是形状为(20, 20)的二维数组。数组x和y构成了X-Y平面上的网格,而数组z则是网格上各点在曲面上的取值。通过cmap参数指定值和颜色之间的映射,即曲面上各点的高度值与其颜色的对应关系。rstride和cstride参数分别是数组的第0轴和第1轴的下标间隔。对于很大的数组,使用较大的间隔可以提高曲面的绘制速度。

除了绘制三维曲面之外,Axes3D对象还提供了许多其他的三维绘图方法。请读者在官方网站查看各种三维绘图的演示程序。