天天看点

STL实现细节之rotate()

引言

STL中 rotate(first, middle, last) 函数的作用是原地把容器区间 [first, middle)(左半部分) 与 [middle, last) (右半部分)的元素互换。

它的实现充分利用了不同迭代器的特性进行算法优化,从而达到最优的性能。以下是libc++(该库用于clang的C++编译器中)中该函数的实现,为了可读性修改了部分变量名以及删除了类型检查:

代码全文

template <class _ForwardIterator>
_ForwardIterator
    rotate_left(_ForwardIterator first, _ForwardIterator last)
{
    typedef typename iterator_traits<_ForwardIterator>::value_type value_type;
    value_type tmp = move(*first);
    _ForwardIterator lm1 = move(next(first), last, first);
    *lm1 = move(tmp);
    return lm1;
}

template <class _BidirectionalIterator>
_BidirectionalIterator
    rotate_right(_BidirectionalIterator first, _BidirectionalIterator last)
{
    typedef typename iterator_traits<_BidirectionalIterator>::value_type value_type;
    _BidirectionalIterator lm1 = prev(last);
    value_type tmp = move(*lm1);
    _BidirectionalIterator fp1 = move_backward(first, lm1, last);
    *first = move(tmp);
    return fp1;
}

template <class _ForwardIterator>
_ForwardIterator
    rotate_forward(_ForwardIterator first, _ForwardIterator middle, _ForwardIterator last)
{
    _ForwardIterator i = middle;
    while (true)
    {
        swap(*first, *i);
        ++first;
        if (++i == last)
            break;
        if (first == middle)
            middle = i;
    }
    _ForwardIterator r = first;
    if (first != middle)
    {
        i = middle;
        while (true)
        {
            swap(*first, *i);
            ++first;
            if (++i == last)
            {
                if (first == middle)
                    break;
                i = middle;
            }
            else if (first == middle)
                middle = i;
        }
    }
    return r;
}

template<typename _Integral>
_Integral
    algo_gcd(_Integral x, _Integral y)
{
    do
    {
        _Integral t = x % y;
        x = y;
        y = t;
    } while (y);
    return x;
}

template<typename _RandomAccessIterator>
_RandomAccessIterator
    rotate_gcd(_RandomAccessIterator first, _RandomAccessIterator middle, _RandomAccessIterator last)
{
    typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
    typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;

    const difference_type m1 = middle - first;
    const difference_type m2 = last - middle;
    if (m1 == m2)
    {
        swap_ranges(first, middle, middle);
        return middle;
    }
    const difference_type g = algo_gcd(m1, m2);
    for (_RandomAccessIterator p = first + g; p != first;)
    {
        value_type t(move(*--p));
        _RandomAccessIterator p1 = p;
        _RandomAccessIterator p2 = p1 + m1;
        do
        {
            *p1 = move(*p2);
            p1 = p2;
            const difference_type d = last - p2;
            if (m1 < d)
                p2 += m1;
            else
                p2 = first + (m1 - d);
        } while (p2 != p);
        *p1 = move(t);
    }
    return first + m2;
}

template <class _ForwardIterator>
_ForwardIterator
    __rotate(_ForwardIterator first, _ForwardIterator middle, _ForwardIterator last,
        forward_iterator_tag)
{
    if (next(first) == middle)
        return rotate_left(first, last);
    return rotate_forward(first, middle, last);
}

template <class _BidirectionalIterator>
_BidirectionalIterator
    __rotate(_BidirectionalIterator first, _BidirectionalIterator middle, _BidirectionalIterator last,
        bidirectional_iterator_tag)
{
    if (next(first) == middle)
        return rotate_left(first, last);
    if (next(middle) == last)
        return rotate_right(first, last);
    return rotate_forward(first, middle, last);
}

template <class _RandomAccessIterator>
_RandomAccessIterator
    __rotate(_RandomAccessIterator first, _RandomAccessIterator middle, _RandomAccessIterator last,
        random_access_iterator_tag)
{
    if (next(first) == middle)
        return rotate_left(first, last);
    if (next(middle) == last)
        return rotate_right(first, last);
    return rotate_gcd(first, middle, last);
}

template <class _ForwardIterator>
_ForwardIterator
    rotate(_ForwardIterator first, _ForwardIterator middle, _ForwardIterator last)
{
    if (first == middle)
        return last;
    if (middle == last)
        return first;
    return __rotate(first, middle, last,
        typename iterator_traits<_ForwardIterator>::iterator_category());
}
           

代码结构

这个小小的功能代码就有一百五十多行,是不是有种太长不看的想法?没关系,我们先来捋一捋函数的调用关系。

STL实现细节之rotate()

首先,可以看到外部调用函数

rotate()

时,会根据传入的迭代器类型自动选择相应的函数版本。总共有三种迭代器的重载版本

  • 前向
  • 双向
  • 随机访问

rotate_left()

函数作用是把容器所有元素向左(起始端)循环移动一位

rotate_right()

则相反,把容器所有元素向右(末端)循环移动一位

如果不考虑对上述两个函数的调用,那么前向迭代版本和双向迭代版本是一样的,最终都是调用

rotate_forward()

最关键的不同在于

rotate_forward()

rotate_gcd()

的实现上。也就是专属于随机访问迭代版本的优化。

前向迭代版

先看看

rotate_forward()

的实现。

template <class _ForwardIterator>
_ForwardIterator
    rotate_forward(_ForwardIterator first, _ForwardIterator middle, _ForwardIterator last)
{
    _ForwardIterator i = middle;
    while (true)
    {
        swap(*first, *i);
        ++first;
        if (++i == last)
            break;
        if (first == middle)
            middle = i;
    }
    _ForwardIterator r = first;
    if (first != middle)
    {
        i = middle;
        while (true)
        {
            swap(*first, *i);
            ++first;
            if (++i == last)
            {
                if (first == middle)
                    break;
                i = middle;
            }
            else if (first == middle)
                middle = i;
        }
    }
    return r;
}
           

观察代码,发现代码分成两部分,而且两部分非常相似。第一部分代码的目的在于获取返回值,如果不需要返回值,完全可以写成下面这样子的:

template <class _ForwardIterator>
void
    rotate_forward(_ForwardIterator first, _ForwardIterator middle, _ForwardIterator last)
{
    _ForwardIterator i = middle;
    while (true)
    {
        swap(*first, *i);
        ++first;
        if (++i == last)
        {
            if (first == middle)
                break;
            i = middle;
        }
        else if (first == middle)
            middle = i;
    }
}
           

下面直接用图来说明吧。

假设容器内元素排布分为ABC三段。

总共分两种情况讨论:

  1. 图中

    middle

    指向B段第一个元素,A与B段等长,目标是把A段挪到最后,形成BCA的结构。这种情况下

    first

    首先到达

    middle

    ,这时原来处于B段的元素已经被移动到正确的位置上,然后调整

    middle

    的位置到i的位置,然后继续对AC段进行rotate操作。
    STL实现细节之rotate()
  2. middle

    指向C段第一个元素,A与C段等长,目标是把C段挪到最前面,形成CAB的结构。这种情况下

    i

    首先到达末尾,C段元素已经就位,调整

    i

    的位置到

    middle

    ,继续对BA段进行操作。
    STL实现细节之rotate()

很明显,这个算法虽然时间复杂度是线性的,但是有部分元素需要多次移动才能达到最终位置上。

随机访问版

接下来看看针对随机访问迭代器的优化版本

template<typename _RandomAccessIterator>
_RandomAccessIterator
    rotate_gcd(_RandomAccessIterator first, _RandomAccessIterator middle, _RandomAccessIterator last)
{
    typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
    typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;

    const difference_type m1 = middle - first;
    const difference_type m2 = last - middle;
    if (m1 == m2)
    {
        swap_ranges(first, middle, middle);
        return middle;
    }
    const difference_type g = algo_gcd(m1, m2);
    for (_RandomAccessIterator p = first + g; p != first;)
    {
        value_type t(move(*--p));
        _RandomAccessIterator p1 = p;
        _RandomAccessIterator p2 = p1 + m1;
        do
        {
            *p1 = move(*p2);
            p1 = p2;
            const difference_type d = last - p2;
            if (m1 < d)
                p2 += m1;
            else
                p2 = first + (m1 - d);
        } while (p2 != p);
        *p1 = move(t);
    }
    return first + m2;
}
           

这个算法的理解就需要花点心思了。还是从示例讲起吧。下面是一个长度为10的数组a,假设我们对它调用

rotate(a, a + 4, a + 10)

,把前四个和后六个元素调换位置。

也就是从这样:

位置 1 2 3 4 5 6 7 8 9
元素值 1 2 3 4 5 6 7 8 9 10

表1

变成这样:

位置 1 2 3 4 5 6 7 8 9
元素值 5 6 7 8 9 10 1 2 3 4

表2

首先

middle

将数组分隔成两段,左段长度为4,右段长度为6,那么他们的最大公约数为

g=2

,那么我们要进行两次循环移位才做。

从位置

p=g-1=1

开始,每隔4个元素往前移4位。

设临时变量

t

,然后以

t<-1<-5<-9<-3<-7<-t

的赋值顺序循环移位。然后变成这样:

位置 1 2 3 4 5 6 7 8 9
元素值 1 6 3 8 5 10 7 2 9 4

表3

完成一次循环移位后,继续

p=p-1=0

,以下面的赋值顺序进行循环移位:

t<-0<-4<-8<-2<-6<-t

完成这次循环移位后,满足

p==0

,算法结束,得到表2的最终结果。

可以发现,每个元素都只被移动了一次,O(n)中的常数应该更小,那么理论上来说它应该比

rotate_forward()

更快。

测试

测试环境:

i5-3210M+8G RAM

Windows 10 64bit

Visual Studio 2015 with Update3

Build target: Release x64

测试代码如下:

auto get_duration(const system_clock::time_point &start, const system_clock::time_point &fin)
{
    auto duration = duration_cast<microseconds>(fin - start);
    return double(duration.count()) * microseconds::period::num / microseconds::period::den;
}

int main()
{
    auto n = ;
    auto data_len = ;
    auto middle = ;
    auto avg_a = , avg_b = ;
    vector<int> data_a(data_len);
    vector<int> data_b(data_len);
    iota(begin(data_a), end(data_a), );
    iota(begin(data_b), end(data_b), );
    while (--n)
    {
        auto start = system_clock::now();
        rotate_gcd(begin(data_b), begin(data_b) + middle, end(data_b));
        auto finish = system_clock::now();
        avg_a += get_duration(start, finish);
        cout << duration << " ";
        start = system_clock::now();
        rotate_forward(begin(data_a), begin(data_a) + middle, end(data_a));
        finish = system_clock::now();
        avg_b += get_duration(start, finish);
        cout << duration << "\n";
    }
    avg_a /= ;
    avg_b /= ;
    cout << avg_a << " " << avg_b << " " << avg_a / avg_b << "\n";
    return ;
}
           

得到测试数据如下:

middle=19

次数 rotate_gcd()(单位:秒) rotate_forward()(单位:秒)
第1次 0.193052 0.034112
第2次 0.171875 0.0369
第3次 0.20115 0.039593
第4次 0.202003 0.043451
第5次 0.189696 0.030061
第6次 0.185225 0.028943
第7次 0.192203 0.025791
第8次 0.155145 0.021634
第9次 0.159483 0.022387
第10次 0.160719 0.022697
平均 0.181055 0.0305569

表4

rotate_gcd()

耗时居然是

rotate_forward()

的6倍!

看上去更复杂巧妙的算法居然比简单的算法还慢。

进一步对耗时数据进行采样(数据规模从两千万缩小到两百万),可以得到下图(纵坐标是耗时,横坐标是

middle

的值,蓝线代表

rotate_gcd()

,绿线代表

rotate_forward()

):

STL实现细节之rotate()

可以看到

rotate_gcd()

耗时是

rotate_forward()

数倍以上,且性能曲线不平稳,实在跟它复杂的原理不匹配。而这个算法在MSVC内建的STL代码中已经移除掉了,大概就是因为这个原因吧。

总结

为什么会导致这样的结果呢?CPU缓存是很大的影响因素。

rotate_forward()

里,可以看到它对元素的访问总是从位置1到位置2再到位置3这样依次访问,CPU可以提前将后面的元素一并加载到缓存中了,免去频繁读取内存的延迟。

rotate_gcd()

对缓存就没那么友好了,它总是每隔一段距离再读取下一个元素,这种做法容易造成cache miss,必须从内存中读取元素,拖慢速度。

因此,在设计算法的时候,不仅仅要考虑算法自身的速度,也要考虑平台相关的优化,比如尽量做到缓存友好,必要时可以使用SIMD等。

参考文献

http://apprize.info/programming/mathematics/11.html

继续阅读