导弹拦截系统升级!

去年就接触了这道经典得不能再经典的“导弹拦截”问题,当时理解了DP算法,自我感觉良好,不知O(n^2)算法还不是其最高境界。后来对这道题的O(nlogn)算法也有听说,但是今天才终于正式开始学习这种神奇的算法了。

其实这种算法也不十分神秘。在这之前,先回顾一下原来的O(n^2)算法。这道题哪里都找得到。

O(n^2)算法

以前做的代码:

#include <iostream>
using namespace std;
int main(){
    int N,missile[15],amount[15];
    cin>>N>>missile[0];
    amount[0]=1;
    for(int i=1;i<N;i++){
        cin>>missile[i];
        int max=0;
        amount[i]=1;
        for(int j=0;j<i;j++){
            if(missile[j]>=missile[i]&&amount[j]>max){
                max=amount[j];
            }
        }
        amount[i]+=max;
    }
    int max=0;
    for(int i=0;i<N;i++){
        if(amount[i]>max){
            max=amount[i];
        }
    }
    cout<<max;
}

大概思路就是说用amount[i]数组来储存包括第i个导弹的最长不上升子序列的长度,然后对于后面每一颗导弹,寻找在它之前的比它高的已经拦截了最多导弹的系统,这样就能维持amount的性质。最后全部扫一遍找amount最大值即序列。如果要输出序列,可以另外用一个数组记录每次的选择从而从最优解中倒推回去。

当时NOIP1999的时候n1000n \leq 1000O(n2)O(n^2)算法完全不虚。但是现在我看到在CodeVS上的数据量已经达到5000了,可能还是会出问题的。很容易看出,外面的那一层for循环一定是必要的,因为你总得把数据输入进来。那寻找符合条件的导弹的时候效率会不会有点低呢?在这里是否可以优化?

优化升级加速

我们使用数组来储存每个导弹的序列长度,搜索的时候寻找比自己高的导弹的最长序列长度,那不是最长的是不是可以舍弃掉呢?不够高的是不是可以舍弃掉呢?不妨换个角度来思考问题,我们可以使用数组来储存每个序列长度指向的导弹。

设b[k]为长度为k的最长不下降子序列的最后一个导弹的最大高度。这里说最大高度,是因为可能有多个子序列满足这个高度,而高度越大,可重复利用的可能就越大,这里可以稍微贪心一下,说这样一定是更好的策略。

现在我们需要证明几个问题:

  • b[]的定义域是连续整数,且从0开始
  • b[]是单调的
  • 开始考虑一个新的导弹的时候,只需要更新比b[]中恰好大于这个导弹高度的元素的后一个

可以先思考一下这些说法是否都正确。之后我们可以十分愉快的O(logn)O(\log n)查找需要的b[k],然后令b[k+1]=max{b[k+1],height},使用类似于栈的方法,O(n)O(n)处理完所有的导弹,就能用O(nlogn)O(n\log n)的算法得到最长不上升子序列的长度了。

证明

b[]的定义域为从0开始的连续整数

这个非常好证明啊!如果你有一个长度为k的最长不上升子序列,那么删掉最后一个元素,就会得到一个长度为k1k-1的最长不上升子序列。那么至少也有一个长度为k1k-1的最长不上升子序列了吧?对于0,已经没有办法再删除元素了。满足条件的b[]靠具体程序来保证。

b[]是单调的

这个想想就觉得是正确的,不过我们可以反证一下。在这个问题中,b[]单调递减,所以我们假设b[k]只需要更新一个元素就能维护性质

只需要更新比b[]中恰好大于这个导弹高度的元素的后一个

假设b[k]为恰好比当前考虑导弹高的最长不上升子序列的高度,b[k]之后的元素是不需要讨论的,因为b[k]之后的元素都比b[k]小,也就是比height小,无法拦截。对于b[k]前面的元素,一定是大于等于b[k]的,如果更新了前面的元素,那么只能往下更新,也就破坏了性质了。所以只要更新一个元素就能维护性质。

我的代码

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
int b[5010],top=0;
int findk(int l,int r,int h){
    if(l==r){
        return l;
    }
    int mid=(l+r>>1)+1;
    if(b[mid]<h){
        return findk(l,mid-1,h);
    }else if(b[mid]>h){
        return findk(mid,r,h);
    }else{
        return mid;
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n,h,k;
    memset(b,-1,sizeof(b));
    cin>>n;
    for(int i=0;i<n;i++){
        cin>>h;
        k=findk(0,top,h);
        if(b[k+1]==-1){
            top++;
        }
        b[k+1]=max(b[k+1],h);
    }
    cout<<top;
}

利用NOI官方题库测试了一下,AC了。只是没能测试到大数据时的性能。

之前听说这种写法很麻烦,但我觉得代码量也不怎么大,难道说我写错了?因为网上还有好多写法和我不太一样……

avatar
Kerry Su