深度学习框架的构建:从基本数据结构到算法

发表时间: 2022-06-16 17:53

我们要实现的元程序库要包含哪些内容呢?这个元程序库并不需要包含非常复杂的数据结构与算法,但应该具有足够的通用性,能够为我们的深度学习框架实现提供有力的支持。STL就是此类通用函数库中的一个典范:它包含的大部分数据结构与算法都比较简单,但被广泛地应用于各种C++程序的开发过程中。当然,C++标准模板库主要被应用于运行期,而我们要实现的元程序库则会在编译期大显身手。应用场景虽有所区别,但这并不妨碍我们借鉴STL的优秀设计。

1 数据结构的表示方法

STL中的主要数据结构可以划分为两类:顺序容器与关联容器。前者通过位置来访问数据,后者通过特定类型的键来访问数据。在运行期可以使用的工具相对较多,相应的数据表示形式也多种多样。以顺序容器为例,在STL中常用的顺序容器就包括vector、list等。这些数据结构各有优劣,用户可以根据具体场景进行选择。

相比之下,在编译期我们能使用的工具就不是那么多了:编译期所处理的是常量——无法修改数据的值将对我们的工具选择造成很大限制;编译期对指针等概念的支持相对较弱,我们也无法在编译期进行动态内存分配并以类似指针的形式保存分配的空间,用于后续访问。这些都限制了我们在构造数据结构时可以选择的工具。如第1章所讨论的那样,在编译期表示容器较方便的方法就是使用变长参数模板。我们会将其作为数据结构的载体,以表示在编译期使用的顺序容器与关联容器。

  • 顺序表:一个变长参数模板实例中的元素是天然有序的。按照C++的惯例,我们将变长参数模板中的元素按照从前到后的顺序赋予相应的索引值,索引值从0开始。比如对于tuple<int, double, char>来说,int、double、char所对应的索引值分别为012。
  • 集合:变长参数模板实例也可以表示集合。比如tuple<int, double, char>同样可以视为一个包含了3个元素的集合。集合中的元素没有顺序性,也即tuple<double, char,int> 所表示的集合与tuple<int, double, char>所表示的等价。另外,通常来说集合中的元素具有互异性,即相同的元素在集合中不会出现多次。因此,对于像tuple<int, char,int>这样的变长参数模板实例来说,是否可以将其视为集合呢?显然,这个实例中存在相同的元素。我们可以拒绝将其视为一个集合,也可以采用其他的方式来解释该实例,比如:无论容器中相同的元素出现多少次,都视为仅出现了一次。采用这种解释时,上述实例也可视为一个集合。要怎么解释容器中的元素是一个选择问题。我们将会在本章的后面讨论不同的选择,以及每种选择所带来的性能差异。
  • 映射:STL中的映射容器采用键-值对存储元素,可以通过键来获取相应的值,我们的元程序库中也将引入类似的构造。我们会使用KVBinder模板来存储键-值对。KVBinder的定义如下[2]:
1    template <typename TK, typename TV>2    struct KVBinder3    {4        using KeyType = TK;5        using ValueType = TV;6        static TV apply(TK*);7    };

KVBinder提供了元数据域来获取键与值的类型。在此基础上,我们可以使用变长参数模板容器来表示映射,比如tuple<KVBinder<int, int*>, KVBinder<char, char*>>——这个映射将一些类型与其指针类型关联了起来。

与集合类似,映射中的键有互异性,因此这里也存在是否将具有相同键的容器视为映射的问题。我们将会在讨论映射实现时分析不同选择所带来的性能差异。

  • 多重映射(multimap):STL提供了multimap来表示多重映射,也即键可以重复的映射。在我们的深度学习框架中,某些地方需要在编译期使用多重映射,因此我们的元程序库中也引入了多重映射。我们使用如下的结构来表示多重映射中的键值关系:
1    KVBinder<Key, ValueSequence<Values...>>

ValueSequence是一个变长参数模板,用于存储某个键所对应的值序列。变长参数模板同时还会作为多重映射的容器使用。一个典型的多重映射实例形如:

1    tuple<KVBinder<int, ValueSequence<char>>,2          KVBinder<double, ValueSequence<int, bool>>> 

它包含了3个键-值对:int-char、double-int与double-bool。

  • 数值容器:细心的读者可能发现了,前面所列出的容器中存储的元素都是类型。这是因为在我们将要实现的深度学习框架中,类型处理占据元程序的主要部分。除此之外,我们也会在某些地方用到与数值相关的元数据结构与算法。但它们与类型容器的处理方式非常相似,因此本章也就不详细讨论了。

可以看出,变长参数模板在我们的元程序库中占据了重要的地位,所有的元数据结构都是以它为载体来实现的。这种设计的缺点在于:给定一个变长参数模板容器,我们很难判断出它所表示的具体含义(序列、集合,还是映射……)。但它也有优点:容器的实例可以自由转换其角色,选择适当的算法。比如,映射可以看成集合(只需要将键-值对看成一个键),因此可以将集合相关的算法应用到映射上;集合又可以看成序列,因此可以将序列相关的算法应用到集合上。我们可以灵活地选择算法达到目的。

还有一点要说明的是:我们使用变长参数模板作为元数据结构的载体,但并不限制变长参数模板的具体类型。在前文中,我们使用了tuple作为示例,但我们也可以采用其他的变长参数模板。比如完全可以自定义一个变长参数模板容器,并使用它来表示序列、集合或映射。

以上就是我们所使用的元数据结构。在此基础上就可以引入一些算法来实现相关的操作了。让我们首先从一些简单的算法开始讨论。

2 基本算法

很多算法都是非常基础且易于实现的。比如获取顺序表尺寸(其中包含的元素个数)的算法:

1    template <typename TArray>2    struct Size_;34    template <template <typename...> class Cont, typename...T>5    struct Size_<Cont<T...>>6    {7        constexpr static size_t value = sizeof...(T);8    };910    template <typename TArray>11    constexpr static size_t Size = Size_<RemConstRef<TArray>>::value; 

这个算法的核心在第7行,它使用C++11中引入的关键字sizeof... 来获取一个类型序列的长度。我们基于这个关键字构造出了元函数Size_。注意,第1~2行是这个元函数模板的声明,而第4~8行是相应元函数的特化实现。正是这个特化实现限定了该元函数只能作用于变长参数模板容器。

在Size_元函数的基础上,我们引入了Size元函数。像第1章讨论的那样,调用Size元函数时,我们不再需要 ::value这样的依赖名称。同时,Size元函数还调用了RemConstRef对输入参数进行变换,使得元函数可以接收常量或引用类型。

RemConstRef的定义如下:

1    template <typename T>2    using RemConstRef = std::remove_cv_t<std::remove_reference_t<T>>;

其中调用了type_traits中的元函数,去掉了输入参数中的引用与常量限定符(如果有)。

因此,我们可以这样调用Size元函数:

1    using Cont = std::tuple<char, double, int>;2    constexpr size_t Res1 = Size<Cont>;3    constexpr size_t Res2 = Size<Cont&>;

其中Res1与Res2的值均为3。注意,Res2之所以能被求值,是因为RemConstRef去掉了输入参数中的引用限定符。

基本算法的另外两个例子是元函数Head与Tail,它们分别用于获取输入序列的首个元素与去除首个元素的子序列。与Size元函数类似,这两个元函数也分别调用了Head_与Tail_来实现各自的逻辑。Head_与Tail_的定义如下:

1     template <typename TSeqCont>2     struct Head_;34     template <template <typename...> class Container, typename TH,5                typename...TCases>6     struct Head_<Container<TH, TCases...>>7     {8         using type = TH;9     };1011    template <typename TSeqCont>12    struct Tail_;1314    template <template <typename...> class Container, typename TH,15             typename...TCases>16    struct Tail_<Container<TH, TCases...>> 17 {17    {18        using type = Container<TCases...>;19    };

类似算法的实现都非常直观。这里就不一一列举了。

3 算法的复杂度

理论上,使用第1章讨论的顺序、分支、循环代码的编写方法,我们可以实现大部分与容器相关的算法。但在实现其他算法之前,让我们首先以Size为例,分析其实现的复杂度。

读者可能会问:我们为什么要关心这些算法的复杂度?事实上,这些算法所对应的代码是在编译期被执行的,也就是说,它们的执行效率基本上不会对代码的运行期造成影响。既然如此,我们真的需要关心它们的实现复杂与否吗?

答案是肯定的。这里需要着重指出一点:即使是在编译期执行的代码,也是需要执行的。这些代码的执行者,实际上是编译器!

我们可以从另一个角度来审视代码的编译过程:我们的源程序就好似一段脚本,而编译器正如脚本的执行者,编译结果则类似脚本的执行结果。从这个角度上来说,编译一段C++代码的过程,与执行一段Python代码没有什么区别,都是需要占用系统资源与运行时间的。如果元函数的复杂度比较高,反复调用就会导致编译用时较长、编译所需内存较多。

另外,将编译的过程与一般脚本的执行过程进行类比并不完全公平。二者虽然有相似之处,但应用场景不同,它们面临的问题也不同。一般的脚本可能会被反复执行,处理的数据量可能较大(可能要以大量的数据作为输入并产生大量的输出),这就对脚本的执行速度产生了相对较高的要求。源代码文件相对较短,同时编译操作的执行频率相对较低(除了开发场景外,一般编译成功之后就不需要再次编译源代码文件了)。因此我们可以对编译器的执行效率有更大的容忍。

但编译器也有编译器的问题,正如我们在第1章所讨论的那样,编译器可能并没有针对元编程引入足够的优化。元函数在执行过程中所产生的实例可能都会保存在编译器的内存中,在整个编译过程中都不会被释放。因此,如果元函数的复杂度较高,可能导致编译器内存超限而编译失败。

对于老式的计算机或32位编译程序来说,这可能是个大问题(32位编译程序能够使用的最大内存容量为4GB,编译复杂的元程序很可能导致内存不足)。当前,主流的计算机是64位的,同时计算机中的内存容量也得到了很大的提高,这能在一定程度上缓解内存不足的问题。但我们依旧需要关注元函数的复杂度,以防在元函数过于复杂、编译项目较大的情况下,编译用时较长或占用内存较多而导致编译失败。

那么,我们要如何衡量元函数的复杂度呢?作为一个普通的C++ 开发者,我们可能对编译器内部的实现原理并不清楚,因此无法做出很精确的估计。但我们至少可以估计出在一个元函数的执行过程中,编译器可能会构造出的实例数,并以此作为元函数复杂度的一种度量:当然,我们希望元函数执行过程中所构造出的实例数越少越好,实例数越多,说明算法越复杂。

让我们回顾一下之前讨论的Size,对于以下的语句:

1    Size<tuple<double, int, char>>

编译器会在执行过程中接收并产生如下的实例:

1    tuple<double, int, char>2    RemConsRef<tuple<double, int, char>>3    Size_<RemConsRef<tuple<double, int, char>>>4    Size_<RemConsRef<tuple<double, int, char>>>::value5    Size<tuple<double, int, char>>

这些实例可能会被一一构造出来并保存在编译器的内存中。不同的实例对应的构造与存储成本并不相同[3]。但我们在这里并不会考虑这种成本差异的细节,只是对算法的复杂度进行粗略的估计。

现在让我们来看一个相对复杂的算法:数组索引,即给定一个数组,获取其中的第N个元素。

读者可能会感到诧异:这是复杂的算法吗?事实上,可能出乎读者的意料,这可能是我们将要实现的最复杂的算法之一了。对运行期数组进行索引非常简单,这是因为从硬件到软件层面上都对其提供了很好的支持。但在编译期,语言规范对这种操作并没有提供足够的支持,这就可能导致相应算法(或者说相应操作)的复杂度非常高。

让我们首先实现一个基础版本,再来分析一下这个版本的复杂度高在何处。利用第1章讨论的顺序、分支、循环代码的编写方法,我们可以相对容易地实现数组索引,算法如下:

1     template <typename TCont, size_t ID>2     struct At_;34     template <template<typename...> class TCont,5               typename TCurType, typename... TTypes, size_t ID>6     struct At_<TCont<TCurType, TTypes...>, ID>7     {8         using type = typename At_<TCont<TTypes...>, ID-1>::type;9     };1011    template <template<typename...> class TCont,12              typename TCurType, typename... TTypes>13    struct At_<TCont<TCurType, TTypes...>, 0>14    {15        using type = TCurType;16    };

At_元函数的实现包含了一个声明与两个模板特化。第1~2行的声明表明该元函数接收两个参数,分别对应输入序列与索引值。后两个特化则形成了一个循环逻辑:第一个特化用于匹配索引值不为0的情况——此时系统会将索引值减1,继续下一步循环;第二个特化匹配索引值为0的情况,此时返回当前类型。这个元函数的使用方式很简单,比如typename At_<tuple<double, int, char>, 2>::type的结果为char。

现在让我们粗略地估计一下该元函数的复杂度。以typename At_<tuple<double, int, char >, 2>::type为例,看一下元函数在执行过程中可能产生的实例个数。不难看出,此时编译器会产生如下的一些实例:

1    At_<tuple<double, int, char>, 2>2    At_<tuple<int, char>, 1>3    At_<tuple<char>, 0>

读者可能意识到了:编译器所产生的实例个数与输入的索引值成正比。这并不是一个好现象。显然,当输入的索引值比较大时,编译器就会产生大量的实例,这同时意味着更长的编译时间,以及更多的内存占用。

事实上,这种实现还存在另一个问题。通常来说,如果将信息保存成一个数组,那么我们往往需要访问数组不同位置处的元素。考虑tuple<double, int, char> 这个数组,在刚刚获取了索引值为2的元素之后,如果我们希望再次调用该元函数获取索引值为1的元素,那么编译器会产生如下的实例:

1    At_<tuple<double, int, char>, 1>2    At_<tuple<int, char>, 0>

读者可能已经发现了,这些实例化的结果与之前实例化的结果完全不同!这就意味着虽然编译器可能在内存中保存了之前的实例化结果,但我们无法从之前的实例化结果中获益。进一步,编译器可能会将这些新的实例保存在内存中,进一步增加编译负担。

希望这个示例能让读者体会到一个实现相对较差的元函数可能对编译器产生的不良影响。一个好的元函数实现应该使得实例化的次数尽量少,同时能尽量地复用之前实例化的结果。如果我们仅仅采用第1章所学习的顺序、分支与循环代码的编写方法,显然无法达到这个目的。要想降低元函数的复杂度,就需要求助于一些特别的技巧。我们将在本章的后续部分讨论一些降低复杂度的技巧。同时,我们将在本章的结尾给出一个低复杂度的序列索引算法实现,但本着从易到难的原则,我们将首先讨论一些相对容易掌握的技巧。首先,让我们来看第一类技巧:基于包展开与折叠表达式的优化。

本文摘自:《动手打造深度学习框架》

本书基于C++编写,旨在带领读者动手打造出一个深度学习框架。本书首先介绍C++模板元编程的基础技术,然后在此基础上剖析深度学习框架的内部结构,逐一实现深度学习框架中的各个组件和功能,包括基本数据结构、运算与表达模板、基本层、复合层、循环层、求值与优化等,最终打造出一个深度学习框架。本书将深度学习框架与C++模板元编程有机结合,更利于读者学习和掌握使用C++开发大型项目的方法。

本书适合对C++有一定了解,希望深入了解深度学习框架内部实现细节,以及提升C++程序设计水平的读者阅读。