分类问题中的算法

    李晓明

    

    

    

    现在,人工智能很热。人工智能应用中有些很基本的需求,分类即为其中最常见的一种。分类关心的是如何将一个对象归到已知类别中。例如,人可以分为少年、青年、中年、老年四个类别,商品可以分为廉价和高档两个类别,学生在课堂上的表现可分为积极踊跃、沉稳细致、不以为然三个类别,网购者可分为随性和理性两种类别,等等。在日常生活中,这种分类主要凭感觉。而如果用计算机来做,则每一种类别都需要通过一些数据特征予以刻画,每一个对象或者说个体都是通过一个“数据点”来表示,如后面要讲的一个网购的例子,其数据特征就用了“每月网购次数”和“每次平均花费”。一个每月网购10次、每次平均花费120元的人则用数据点(10,120)来表示。

    不难体会到,把一个个体归到预设的类别中,有些比较容易,有些则很不容易。鉴于分类在现实中有广泛的应用,人们发明了多种算法,应对各种不同的情形。本文介绍两个最基本的算法——K近邻(KNN)算法和支持向量机(SVM)算法。

    这两个算法的基本思路都很简单直观。KNN的实现容易理解,因而人工智能的教材里一般都会有它。SVM的常规实现需要较深的数学基础,因而一般初等人工智能教材都没有它。不过,由于SVM实在很重要,最近也看到有尝试在中学人工智能教材中介绍它[1],只是在介绍了算法基本思想后用了一句“这个问题可以用优化方法求解,但具体过程已经超出了同学们现阶段的知识储备,等到同学们以后学习了相关的数学工具后,就可以求解了”就结束了。本文将从一个不同的视角描述一个实现SVM的算法,该算法从效率上讲没有常规的SVM算法高,但用到的数学知识在中学生可理解的范畴内,他们应该能够编程实现,因而适合中学的教学。

    一般而言,设计分类算法的目的是实现一个所谓“分类器”(程序)。分类器的实现通常都是基于一批已知类别的数据,形成某些规则,来做未知类别对象的类别判断,图1是一个概念图。

    实现一个分类器的基础是一个预先给定的类别集合C={C1,C2,...,Cm}和一批已知类别的样本数据集D={d1,d2,...,dn}。不同分类算法的区别一般体现在所形成的分类规则上。

    类别集合C通常是人们根据需要或经验事先确定的,有一定现实含义,如本文开始时提到的那些。样本数据集D怎么做到类别已知的呢?通常是采用人工标注给定,也就是事先找来一批有代表性的数据,请有背景知识的人一一给打上类别标签。这项工作一方面很重要,对后续自动分类的质量有基础性影响,另一方面在互联网经济中的需求越来越普遍,因而形成了一种新职业——数据标注员。

    在分类问题中,一个核心的概念是两个数据点之间的距离。所谓判断一个数据点该属于哪个类,本质上就是看它离哪个类的已知数据点更近。而“距离”在不同的应用背景下可能有不同的定义。我们以二维数据空间为例,给出三种常见的距离定义,如图2所示。

    设(x1,y1)和(x2,y2)为两个数据点p1和p2的坐标,欧式距离、曼哈顿距离和余弦相似度①分别为:

    其中,前面两个都有“数值越小越接近(相似)”的含义,而余弦相似度定义在区间[0,1],越大越相似。若余弦相似度为1,意味着两个数据点同在一条通过原点的直线上。体会这几个定义的含义,读者对它们适用的不同场合能有些直觉认识。

    一般而言,KNN和SVM处理的数据点都可以是高维的,用多于三个特征分量来表示一个对象的特征。本文为突出要点,只考虑二维的情形,于是有如图3所示的视觉形象,便于解释有关细节。

    图3(a)示意在二维数据空间中有两种已经标注(分别用圆和三角表示)类型的样本数据。它们大致分布在空间中两个不同的区域,任何两个数据点之间都可以谈论某种距离。我们特别注意到,同类数据之间的距离不一定就比异类之间的小。图3(b)示意出现了一个未知类别的数据(x),它应该属于哪一类呢?

    ● KNN算法基本思路

    针对样本数据D,KNN算法采用了一种可以说是体现“近朱者赤,近墨者黑”和“少数服从多数”原则的直截了当的思路。它一一计算待分类数据x与样本数据集D中所有数据的距离,然后取其中最小的K个(也就是“KNN”中的K,而NN表示“最近的邻点”),看它们分别属于哪一个类,判定x应该属于K中出现较多的那个类。

    采用什么距离定义和具体应用有关,为和后面的例子对应,下面的算法描述采用了余弦相似度作为距离。

    ● KNN分类算法(如表1)

    ● 算法运行例

    假设我们考虑对网购者的分类,用“每月网购次数”和“每次平均花费”两个量来刻画每一个用户,要看一个人是“随性”(S)还是“理性”(R)②。有一个已经人工标好类型的样本数据如图4(a)所示,现在有一个用户每月网购5次,平均花费40元,即她的数据是x=(5,40)。问她是属于随性网购者还是理性网购者?

    采用KNN算法对这个用户分类,采用余弦相似度作为距离度量,首先算得x=(4,50)与16个已知数据的距离,如图4(b)第1列所示。为方便查看,我们把那些距离按照与1接近的程度的排序放在表中最右边一列。(注意,对余弦距离而言,越接近1表示越“近”)

    现在,如果取K=1,看到离x最近的(6,45)是“S”,于是x应该被分类为“S”。如果取K=3,离x最近的3个都是“S”,如果取K=5,离x最近的5个里有4个“S”1个“R”,等等。你觉得应该认为x是随性网购者还是理性网购者呢?显然,认为这个x是一个随性网购者比较合理。

    ● 算法性质的分析

    这是一个正确的算法吗?如果考虑的是2-分类,即类别数为2,且K为奇数,KNN算法总会有一个输出,建议x应该属于哪一类。因此不存在停机或收敛之类的问题。问题可能在于它给出的建议到底有多靠谱。这样,除了给出x当属于哪个类别外,还可以给出一个概率,即在TOP-K中,占优类型的数据在整个K中的占比。例如,在上面的例子中,K=3,這个占比就是100%,K=5,这个占比就是80%。如果类别数大于2,则还需要有一个方法来做“平手消解”,当某两类在TOP-K中有相同出现时决定取哪一个。

    一个分类器的质量常常用“准确度”(accuracy)指标来评价。假设一共有p个测试数据x1,x2,…,xp,对它们分完类后人工一一做检查。用r表示分类错误的个数,准确度就是p/(p+r)。

    这个算法的效率如何呢?

    KNN分类算法的计算复杂度与样本集大小(n)有关,与样本属性的维数有关。在我们讨论的二维情况下,从算法描述中可以看到,计算余弦距离时间复杂度是O(n)。算法第3行找出n个相似度中的TOP-K,一般算法是O(k*n),采用适当的数据结构可以做到O(k*logn)。

    在学过了算法园地专栏前面十多个算法后,现在学KNN算法,读者可能会产生一种十分不一样的感觉。首先,这种算法看起来好简单,直截了当好理解,其逻辑与前面讨论过的那些相比要浅显许多。其次,给出的分类结果到底有多靠谱,一般来说是没法证明的。事实上,像分类这样的问题,现实中经常就是没有客观标准。但为了能推进研究,人们通常会准备一些有代表性的测试数据集(benchmark)用于比较不同方法的效果。当然,最终效果如何,只能通过实际应用检验了。

    ● SVM概念与算法目标

    下面讨论SVM,它有一个听起来很高深的名字“支持向量机”。我们不被它困扰,来看具体是什么意思。还是用图3的数据,画出SVM分类概念的示意图,如图5(a)所示。

    可以看到,图中除了数据点,还有将两类数据点分开的直线。SVM就是要基于样本数据点,算出一条那样的直线(y=ax+b),从而可对拟分类数据依照它是在直线的左右来赋予类别,这个示例就是左边为“圆”,右边为“三角”。不过,读者可能马上注意到图5(a)中有两条直线,它们对于待分类数据点(x)分类的结论是不同的,于是就有了哪个更好的问题。SVM怎么考虑呢?SVM要求一条“最优的”直线。

    什么叫“最优的”直线?SVM采用的观点是离它最近的数据点尽量远。形象地看,就是在两类数据之间“最窄”处的中线。图5(b)是一个示意,根据两类数据点的情况,我们分别画了一个外包凸多边形(这种多边形在计算几何学中叫“凸包”),这样它们之间的“通道”也看得很清楚了③。如果我们能确定两个凸包上距离最近的两个点(不一定都恰好在顶点上),做连接它们的线段,再做该线段的垂直平分线,就得到了SVM的结果,如图5(c)所示。

    下面给出的SVM算法就是求这样一条直线的算法。读者能感到与前面的KNN很不同,那里是基于样本数据直接对数据点(x)做分类。不过,读者也能意识到,一旦有了这样一条直线,判断一个数据(x)应该属于哪一类就是一件平凡的事情,用不着专门说了。

    ● 一种SVM分类器算法

    下面描述的算法相对比较宏观(如表2),有利于读者把握整体概念。其中的细节在后面做进一步阐述。另外,所提到的距离此时均为欧几里德距离。

    下面来讨论其中的要点。首先,理解这个算法的细节(包括编程实现)所需的主要数学知识为平面向量的知识,这在高中新教材中已有覆盖,见参考文献[2]的第6章。具体包括平面上两个点的距离公式、一个点到一条直线的距离公式(由此得到点到线段的距离公式)、根据一个点的坐标和斜率确定直线y=ax+b的方法等。下面以算法中的最后一步(5)为例,展示处理这类问题的样式。已知由(x1,y1)和(x2,y2)确定的线段,要确定垂直平分线y=ax+b中的参数a和b。

    算法的第2、3、4步只涉及距离的计算和排序,不用赘述。算法的第1步,从一个平面数据点集得到它的凸包。这是计算几何学中的一种基本运算,人们研究出了许多简单易懂的算法,陈道蓄教授在本专栏上一期,也就是2020年第21期上介绍了其中两种[3],在此也不赘述,有兴趣的读者可自行查阅学习。

    ● 算法的性质分析

    先看计算复杂性。若第1步采用[3]中介绍的第二个算法,复杂性是O(nlogn)。第2步和第3步计算距离,若不采用任何优化,为O(n2)。第4和第5步是常数时间。总的来说,复杂性为O(n2)。

    为什么算法得到的直线y=ax+b就是最优的一条直线?由于它是在两个最短距离的点(x1,y1)和(x2,y2)之间,这意味着两类数据之间不可能有更宽的通道。而由于它是“垂直平分线”,这意味着它做到了让“离它最近的数据点尽量远”。

    不过,细心的读者可能问到一个更加微妙的问题。那就是,记即数据点(? ? ? )和(? ? ? )之间线段的长度,那么两个数据点离上述直线的距离都是d/2。为什么样本数据中不可能有其他的数据点,离直线的距离小于d/2呢?这与凸包的性质有关,与直线是“垂直平分线”有关,也与算法中第3步强调了要包括“点与边的距离”有关。鼓励有兴趣的读者思考体会一下这背后的“玄机”。当然,也欢迎有疑问的读者与我们交流。

    本文形成过程中与陈道蓄教授有过深入讨论,他指出SVM算法中的第2步可以省去,进而第4步也可以省去了。想想的确如此。不过文中保留了原样,留作读者思考为什么那两步可以省去。

    释:①即平面上两个向量夹角的余弦值,其表达式的推导可见高中数学教材[2]。

    ②用什么特征来表示所关心的类型,与对应用背景的理解直接相关。这里采用两个特征分量只是用来说明算法运行的过程,实际应用中为区分随性和理性消费者会比这要求更多。

    ③这里,我们总假设两个凸多边形是不重叠的。

    参考文献:

    [1]汤晓鸥,陈玉琨.人工智能基础(高中版)[M].上海:华东师范大学出版社,2018,4:35.

    [2]薛彬,张淑梅.数学(第二册)[M].北京:人民教育出版社,2019,7:34.

    [3]陳道蓄.平面上的凸包计算[J].中国信息技术教育,2020(21):25-29.

    注:作者系北京大学计算机系原系主任。