TopCoder Member SRM501 Div1Hard

問題

遺跡において宝石を集めたい。行える行動は、「1マス深く潜る(timeYの時間がかかる)」と、「1マス横に移動する(timeXの時間がかかる)」である。横への移動は高々LR回しか行えない。宝石には価値Viが付いていて、集めた宝石の価値の総和をgoalValue以上にしたい。そのとき、かかる時間の最小値を求めよ。

解法

宝石の位置と、横移動の回数をキーにDP。当然宝石は浅い方から順に見ていく。移動の前後で、座標をx、横移動の回数をiとすると、「x+i」もしくは「x-i」は一定である。横移動の回数は単調増加であるから、Segment Treeを用いると価値の最大値を求められる。
同じ深さに複数の宝石があり得る。「そこより浅い宝石から直接到達する場合の最善値」を求める→その深さで左から順になめる→右から順になめる→左からと右からをあわせる という方法で求めることができる。下のコードでは、同じ位置に宝石が複数あるとまずいので最初に同じ位置の宝石をまとめている。

コード

大変長くて読みにくくて汚いコード

#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

const long long minf_l = -(1LL<<62LL);
const int minf = 1<<31;

class seg_tree
{
	int data[2048];
	
public:
	seg_tree()
	{
		for(int i=0;i<2048;i++) data[i] = minf;
	}
	
	void __set(int pos, int val)
	{
		data[pos] = max(data[pos], val);
		pos /= 2;
		while(pos){
			data[pos] = max(data[pos*2], data[pos*2+1]);
			pos /= 2;
		}
	}
	
	void set(int pos, int val)
	{
		__set(pos + 1024, val);
	}
	
	int __query(int right)
	{
		int ret = minf;
		for(;right;){
			if(right&1){
				ret = max(ret, data[--right]);
			}
			right >>= 1;
		}
		return ret;
	}
	
	int query(int right)
	{
		return __query(right+1024);
	}
	
	int pos(int p)
	{
		return data[p+1024];
	}
};

seg_tree seg1[2001], seg2[2001];
int store[1001][1001];

class FoxSearchingRuins
{
	long long _x[1000], _y[1000], _v[1000];
	int x[1000], y[1000], v[1000];
	pair<int, pair<int, int> > y_sort[1000];
	int W, H, N, LR, goal, timeX, timeY;
	long long ret;
	
	void prepare(int p)
	{
		for(int i=0;i<=LR;i++){
			//int bval = max(seg1[i+x[p]].query(0, i+1), seg2[i-x[p]+1000].query(0, i+1));
			int bval = max((seg1[i+x[p]].query(i+1)),(seg2[i-x[p]+1000].query(i+1)));
			if(i==0) bval = max(bval, 0);
			if(bval!=minf) bval += v[p];
				
			bval = min(bval, goal);
			//printf("%d %d: %d\n", p, i, bval);
			
			store[p][i] = bval;
		}
	}
	
	void restore(int p)
	{
		for(int i=0;i<=LR;i++){
			if(store[p][i]==minf) continue;
			seg1[i+x[p]].set(i, store[p][i]);
			seg2[i-x[p]+1000].set(i, store[p][i]);
		}
	}
	
	void parse0(int p)
	{
		//printf("%d: %d %d %d\n", p, x[p], y[p], v[p]);
		for(int i=0;i<=LR;i++){
			//int bval = max(seg1[i+x[p]].query(0, i+1), seg2[i-x[p]+1000].query(0, i+1));
			int bval = (seg1[i+x[p]].query(i));
			if(i==0) bval = max(bval, 0);
			if(bval!=minf) bval += v[p];
				
			bval = min(bval, goal);
			//printf("%d %d: %d\n", p, i, bval);
			
			if(bval==minf) continue;
			seg1[i+x[p]].set(i, bval);
		}
	}
	
	void parse1(int p)
	{
		for(int i=0;i<=LR;i++){
			//int bval = max(seg1[i+x[p]].query(0, i+1), seg2[i-x[p]+1000].query(0, i+1));
			int bval = seg2[i-x[p]+1000].query(i);
			if(i==0) bval = max(bval, 0);
			if(bval!=minf) bval += v[p];
				
			bval = min(bval, goal);
			//printf("%d %d: %d\n", p, i, bval);
			
			if(bval==minf) continue;
			seg2[i-x[p]+1000].set(i, bval);
		}
	}
	
	void parse01(int p)
	{
		//printf("%d: %d %d %d\n", p, x[p], y[p], v[p]);
		for(int i=0;i<=LR;i++){
			int bval = max(seg1[i+x[p]].query(i+1), seg2[i-x[p]+1000].query(i+1));
			if(i==0) bval = max(bval, 0);
			if(bval!=minf) bval += v[p];
				
			bval = min(bval, goal);
			//printf("%d %d: %d\n", p, i, bval);
			
			if(bval==minf) continue;
			seg1[i+x[p]].set(i, bval);
			seg2[i-x[p]+1000].set(i, bval);
		}
	}
	
	void adjust(int p)
	{
		for(int i=0;i<=LR;i++){
			int bval = max(seg1[i+x[p]].pos(i), seg2[i-x[p]+1000].pos(i));
			if(bval >= goal){
				ret = min(ret, i*timeX + _y[p]*timeY);
			}
			if(bval==minf) continue;
			seg1[i+x[p]].set(i, bval);
			seg2[i-x[p]+1000].set(i, bval);
		}
	}
	
public:
	long long theMinTime(int _W, int _H, int _N, int _LR, int _goal, int _timeX, int _timeY, vector<int> seeds)
	{
		ret = -minf_l;
		W = _W; H = _H; N = _N; LR = _LR; goal = _goal; timeX = _timeX; timeY = _timeY;
		
		_x[0] = (seeds[1] * (long long)seeds[0] + seeds[2]) % W;
		_y[0] = (seeds[4] * (long long)seeds[3] + seeds[5]) % H;
		_v[0] = (seeds[7] * (long long)seeds[6] + seeds[8]) % seeds[9];
		for(int i=1;i<N;i++){
			_x[i] = (seeds[1] * _x[i-1] + seeds[2]) % W;
			_y[i] = (seeds[4] * _y[i-1] + seeds[5]) % H;
			_v[i] = (seeds[7] * _v[i-1] + seeds[8]) % seeds[9];
		}
		for(int i=0;i<N;i++){
			x[i] = (int)_x[i];
			y[i] = (int)_y[i];
			v[i] = (int)_v[i];
			//printf("%d %d %d\n", x[i], y[i], v[i]);
			y_sort[i] = make_pair(y[i], make_pair(x[i], i));
		}
		for(int i=0;i<N;i++){
			for(int j=0;j<i;j++){
				if(x[i]==x[j]&&y[i]==y[j]){
					v[i] += v[j];
					if(v[i]>=goal) v[i] = goal;
					v[j] = 0;
				}
			}
		}
		sort(y_sort, y_sort + N);
		
		int pos = 0;
		while(pos < N){
			int last = pos;
			while(last < N && y_sort[pos].first == y_sort[last].first) last++;
			
			for(int i=pos;i<last;i++){
				prepare(y_sort[i].second.second);
			}
			for(int i=pos;i<last;i++){
				restore(y_sort[i].second.second);
			}
			if(last-pos>1){
				for(int i=pos;i<last;i++){
					parse1(y_sort[i].second.second);
				}
				for(int i=last-1;i>=pos;i--){
					parse0(y_sort[i].second.second);
				}
			}
			for(int i=pos;i<last;i++){
				adjust(y_sort[i].second.second);
			}
			
			pos = last;
		}
		
		if(ret == -minf_l) return -1;
		return ret;
	}
};