终于下定决心开始学主席树了,先找了一道不带修改的区间第K大做。关于带修改的区间第K大,参见我的下一篇博客 {% post_link zoj2112 %}

主席树其实就是一种可持久化的线段树,是函数式编程的一个很经典的应用(另外的话,我就只知道FANHQ_TREAP了)。主要思想是在对线段树的节点进行修改的时候,不改变原节点,而是新开一个节点并链到数中。

用主席树对一个长度为n的数列做区间第K大需要建n棵线段树,第i棵线段树维护前i个数的信息。线段树对数列的值二分,一个覆盖[l, r]的节点表示前i个数中值的大小在[l, r]内的数的个数。如数列1,2,3,5,4的第4棵线段树记录1,2,3,5的信息,这棵线段树上覆盖[1, 5]的节点的tot=4,因为所有数的大小都在[1, 5]内,而覆盖[1, 4]的节点的tot=3,因为只有1,2,3在[1, 4]内。

那我们怎么找区间第k大呢?假设我们需要知道区间[l, r]内(这里的l和r表示数的下标,不同于上面表示值)的区间第k大,那我们需要一棵记录[l, r]内数的信息的线段树,我们之前已经建好了数列所有前缀的线段树,稍加思考就可以知道,这样一棵线段树可以由第r棵线段树减去第l - 1棵线段树得到。比如我要得到[3, 5]内数的大小在[1, 5]之间的数的个数,那我们可以通过前5个数中大小在[1, 5]之间的数的个数减去前2个数中大小在[1, 5]之间的数的个数得到。

好像做了那么久都没有用到可持久化啊,这样的空间复杂度和时间复杂度依然在(O(n^2))级别,我还不如写暴力。那怎么办呢?前面我们讲到了函数式编程的方法。首先,由于主席树是根据值做二分,每棵主席树其实长得都一样,而且第i棵主席树和第i - 1棵主席树许多节点也是一样的,要修改的只是长度为(log_2n)的一条链,那我们在生成第i棵主席树的时候就可以用到第i棵主席树上大多数的节点,对于需要修改的只需要新建节点即可。这样的话时间和空间复杂度就降到了(O(nlog_2n))的级别。

再多加考虑的话会发现很多主席数上的节点都记录着0,而且一个节点记录着0意味着它的两棵子树上的所有节点记录的也都是0,那么这些节点就是多余的。所以初始的时候我们只需要一个空节点,在初始化的过程中再往树上增加节点就又可以很大程度上节约内存。

如果对原数列进行离散化可以进一步节约空间和时间,这样做下来,在POJ上这道题我一共开了200W的线段树节点就AC了, 内存40+M,时间1.5S(写得渣求不BS)。下面是代码(最后有一个调试的时候用的暴力,请无视之)。

//POJ 2104.cpp
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<string>
#include<cmath>
#include<queue>
#include<stack>
#include<set>
#include<map>

//#define LOCAL
//#define READ_FILE
#define READ_FREAD
//#define VISUAL_STUDIO

#ifdef LOCAL
#define LOCAL_TIME
#define LOCAL_DEBUG
//#define LOCAL_SIMP
#endif

const int MAXN = 100010;
const int MAXTREE = 2000000;

#ifdef VISUAL_STUDIO
#pragma warning(disable: 4996)
#endif

#ifdef LOCAL_TIME
#include<ctime>
#endif

#ifdef READ_FREAD
char fread_char;

inline void fread_init()
{
	fread_char = getchar();
}

inline int get_int()
{
	int ret = 0;
	while ((fread_char < '0') || (fread_char > '9'))
	{
		fread_char = getchar();
	}
	while ((fread_char >= '0') && (fread_char <= '9'))
	{
		ret = ret * 10 + fread_char - '0';
		fread_char = getchar();
	}
	return ret;
}
#endif

inline void read(int &a)
{
#ifdef READ_FREAD
	a = get_int();
#else
	scanf("%d", &a);
#endif
}

struct discre_node
{
	int val, id;
	inline bool operator < (const discre_node &b) const
	{
		return val < b.val;
	}
};

struct tree_node
{
	int tot;
	tree_node *l_child, *r_child;
	inline tree_node *init()
	{
		l_child = r_child = NULL;
		tot = 0;
		return this;
	}
};

tree_node recover_tree_node[MAXTREE];
tree_node *stack_recover_tree_node[MAXTREE];
int top_stack_recover_treee_node = MAXTREE;

inline void init_recover_tree_node()
{
	for (int i = 0; i < MAXTREE; i++)
	{
		stack_recover_tree_node[i] = recover_tree_node + i;
	}
}

inline tree_node *new_node()
{
	return stack_recover_tree_node[--top_stack_recover_treee_node]->init();
}

int n, m;
discre_node discre_tmp[MAXN];
int discre_to_num[MAXN];
int discre[MAXN];
int max_dis;
int l, r, k;
tree_node *prefix[MAXN];

inline void tree_insert(const int i)
{
	prefix[i] = new_node();
	tree_node *last_root = (i == 0 ? NULL : prefix[i - 1]);
	tree_node *new_root = prefix[i];
	int l = 1, r = max_dis, mid;
	while (l < r)
	{
		mid = (l + r) >> 1;
		if (discre[i] <= mid)
		{
			new_root->r_child = (last_root == NULL ? NULL : last_root->r_child);
			if (last_root != NULL)
			{
				new_root->tot = last_root->tot + 1;
				last_root = last_root->l_child;
			}
			else
			{
				new_root->tot = 1;
			}
			new_root->l_child = new_node();
			new_root = new_root->l_child;
			r = mid;
		}
		else
		{
			new_root->l_child = (last_root == NULL ? NULL : last_root->l_child);
			if (last_root != NULL)
			{
				new_root->tot = last_root->tot + 1;
				last_root = last_root->r_child;
			}
			else
			{
				new_root->tot = 1;
			}
			new_root->r_child = new_node();
			new_root = new_root->r_child;
			l = mid + 1;
		}
	}
	if (last_root != NULL)
	{
		new_root->tot = last_root->tot + 1;
	}
	else
	{
		new_root->tot = 1;
	}
}

inline int left_size(tree_node *query)
{
	if (query == NULL)
	{
		return 0;
	}
	if (query->l_child == NULL)
	{
		return 0;
	}
	return query->l_child->tot;
}

inline int tree_k_th(const int l, const int r, const int k)
{
	tree_node *left_root = (l == 0 ? NULL : prefix[l - 1]);
	tree_node *right_root = prefix[r];
	int binary_l = 1, binary_r = max_dis, mid;
	int tot_left = 0;
	while (binary_l < binary_r)
	{
		mid = (binary_l + binary_r) >> 1;
		if (tot_left + left_size(right_root) - left_size(left_root) >= k)
		{
			right_root = right_root->l_child;
			if (left_root != NULL)
			{
				left_root = left_root->l_child;
			}
			binary_r = mid;
		}
		else
		{
			tot_left += left_size(right_root) - left_size(left_root);
			right_root = right_root->r_child;
			if (left_root != NULL)
			{
				left_root = left_root->r_child;
			}
			binary_l = mid + 1;
		}
	}
	return binary_l;
}

#ifndef LOCAL_SIMP
int main()
{
#ifdef LOCAL_TIME
	long long start_time_ = clock();
#endif
#ifdef READ_FILE
	freopen("2104.in", "r", stdin);
	freopen("2104.out", "w", stdout);
#endif
#ifdef READ_FREAD
	fread_init();
#endif
	read(n);
	read(m);
	init_recover_tree_node();
	for (int i = 0; i < n; i++)
	{
		scanf("%d", &discre_tmp[i].val);
		discre_tmp[i].id = i;
	}
	std::sort(discre_tmp, discre_tmp + n);
	max_dis = 0;
	for (int i = 0; i < n; i++)
	{
		if ((i == 0) || (discre_tmp[i - 1] < discre_tmp[i]))
		{
			discre_to_num[++max_dis] = discre_tmp[i].val;
		}
		discre[discre_tmp[i].id] = max_dis;
	}
	for (int i = 0; i < n; i++)
	{
		tree_insert(i);
	}
	for (int i = 0; i < m; i++)
	{
		read(l);
		read(r);
		read(k);
		printf("%dn", discre_to_num[tree_k_th(l - 1, r - 1, k)]);
	}
#ifdef LOCAL_TIME
	printf("run time: %lld msn", clock() - start_time_);
#endif
#ifdef READ_FILE
	fclose(stdin);
	fclose(stdout);
#endif
	return 0;
}
#else
int a[MAXN];
int sort_tmp[MAXN];

int main()
{
	const int N = 10;
	const int M = 10;
	const int RAND_BASE = 243;
	freopen("2104.in", "w", stdout);
	srand(RAND_BASE);
	printf("%d %dn", N, M);
	for (int i = 0; i < N; i++)
	{
		printf("%d ", rand());
	}
	printf("n");
	for (int i = 0; i < M; i++)
	{
		int l = rand() % N + 1;
		int r = l + rand() % (N - l + 1);
		int k = rand() % (r - l + 1) + 1;
		printf("%d %d %dn", l, r, k);
	}
	fclose(stdout);
	freopen("2104.in", "r", stdin);
	freopen("2104std.out", "w", stdout);
	scanf("%d %d", &n, &m);
	for (int i = 0; i < n; i++)
	{
		scanf("%d", a + i);
	}
	for (int i = 0; i < m; i++)
	{
		scanf("%d %d %d", &l, &r, &k);
		memcpy(sort_tmp, a + l - 1, (r - l + 1) * sizeof(int));
		std::sort(sort_tmp, sort_tmp + r - l + 1);
		printf("%dn", sort_tmp[k - 1]);
	}
	fclose(stdin);
	fclose(stdout);
	return 0;
}
#endif