Stl Upper_bound函数实现

写了一个upper_bound的实现。其中递归使用二分法求解最上界,虽然写的完全不像STL的风格,但是练手还是可以的。

#include #include #include #include
using namespace std; int UpperBound(int* a, int start, int end , const int&
value){ int mid = 0; int index = 0; int up_index = 0; if(a[end]<value) //
特判比最后一个数大时,end+1 return end+1; //这里是经典的二分 while(start<= end){ mid = start +
(end - start)/2; if(a[mid] == value){ index = mid + 1; // 去递归的找一下上界 up_index =
UpperBound(a,mid+1,end,value); // 如果找到的上界比现在的还小,那么就用现在的 return up_index >
index? up_index : index; } else if(a[mid]<value){ start = mid+1; } else{ end =
mid-1; // 记录上界 index = mid; } } return index; }

如果原数组中没有存在那个元素,就根本没有调用那个递归程序,递归只有在出现多个此元素时才会调用。另外中间递归调用段地方还可以改写为:

if(a[mid] == value){ index = mid + 1; /* up_index =
UpperBound(a,mid+1,end,value); return up_index > index? up_index : index; */
// 因为只有在递归调用中start>end时,才会返回一个index = 0 return (mid+1>end) ? index :
UpperBound(a,mid+1,end,value); }

这样写完后写一下测试代码,顺便wrap一层upper_bound:

int upper_bound(int * a,int n, const int& value){ // 使用宏可以在编译时去除 #ifdef TEST
// pre-condition assert(NULL != a && n>0); // check the array a is sorted by
elements for(int i = 1; i < n; i++){ assert(a[i-1]<=a[i]); } int * tmp_a = new
int[n]; memcpy(tmp_a,a,sizeof(int)*n); #endif int index =
UpperBound(a,0,n-1,value); #ifdef TEST // post-condition //check whether the
array is changed or not for(int i = 0; i < n; i++) assert(a[i] == tmp_a[i]);
if(index == 0){ // min assert(a[index]>value); } else if(index == n){ // max
assert(a[index-1]<=value); } else{ assert(a[index]>value &&
a[index-1]<=value); } delete [] tmp_a; #endif return index; }

如果希望别人不调用UpperBound而只调用upper_bound,可以使用static 关键字, 将UpperBound限制在只在本文件中使用。别人调用
就只能调用upper_bound()函数。不过STL的教学源码比我这精简的多,根本无须使用递归。真正的STL源码变量名称会使用下划线__作为起始。

template <class ForwardIterator, class T> ForwardIterator upper_bound (
ForwardIterator first, ForwardIterator last, const T& value ) {
ForwardIterator it; iterator_traits::distance_type count,
step; count = distance(first,last); while (count>0) { it = first;
step=count/2; advance (it,step); if (!(value<*it)) // or: if
(!comp(value,*it)), for the comp version { first=++it; count-=step+1; } else
count=step; } return first; }

这里的advance()函数定义类似:

template<class Iterator, class Distance> void advance(Iterator& i, Distance
n); 类似于 i = x.begin()+n; i is an iterator of x

最后把主函数贴上做结:

int main(int argc, char* argv[]) { // 这里你可以使用random的方法进行测试。 int x[] =
{1,2,3,4,5,5,5,5,9}; vector v(x,x+9); sort(v.begin(),v.end()); cout <<
(int)(upper_bound(v.begin(),v.end(),9)-v.begin()); cout << upper_bound(x,9,9);
return 0; }