PS/BOJ

[BOJ] 백준 24520. Meet In The Middle (Platinum IV)

kth990303 2022. 3. 25. 03:38
반응형

알고리즘 중급 스터디에서 과제로 해결해야 했던 문제.

포인트를 놓쳐 생각보다 굉장히 많이 삽질했다.

문제는 아래와 같다.

https://www.acmicpc.net/problem/24520

 

24520번: Meet In The Middle

첫 번째 줄에 마을의 수 $N$, 약속의 수 $K$가 주어진다. $(1 \le N, K \le 100\,000)$ 이어지는 줄부터 $N-1$개의 줄에 도로 정보를 나타내는 세 정수 $u$, $v$, $w$가 주어진다. $u$번 마을과 $v$번 마을 사이에

www.acmicpc.net


의식의 흐름 및 해설

N이 10만이기 때문에 DFS O(N)으로 모든 노드를 훑으면 시간초과이다.

따라서 특정 노드만 확인해주면 되는데, 정점들 사이의 거리는 LCA로 O(logn)에 구할 수 있음이 자명하다. 

이 때의 거리를 d라고 하면, 

d가 홀수일 땐 당연히 -1,

d가 짝수일 땐 특정 노드로부터 d/2만큼 떨어진 위치의 노드가 존재하는지 확인해주면 된다.

 

처음에는 더 멀리 있는 노드를 lca로 올려 다시 solve(lca, 가까이 있는 노드)로 진행해서 해결하려 했으나,

2^i번째와 2^(i+1)번째 사이의 노드가 정답인 경우를 놓쳐서 삽질을 많이 했다.

7번 노드의 2^1번째 부모, 2^2번째 부모의 사이에 있는 4번 노드가 답인 경우

중간 경우를 찾아주기 위해 특정 경우에서 추가로 재귀 및 이분탐색으로 해결해주었다.


시행착오

아래와 같이 루트로부터 거리가 먼 노드를 lca(n1, n2)쪽으로 계속 올려주는 방법은

인덱스를 벗어날 위험성도 높을 뿐더러, 2^i번째와 2^(i+1)번째 사이의 노드를 검증하지 못한다.

if(dist1==dist2)return lca(n1,n2);
else if(dist1>dist2)return solve(n1,lca(n1,n2));
else return solve(lca(n1,n2),n2);

코드

#include <bits/stdc++.h>
#define sz(v) (int)v.size()
#define all(v) (v).begin(), (v).end()
#define press(v) (v).erase(unique(all(v)), (v).end())
using namespace std;
typedef long long ll;
typedef pair<int, int> pi;
typedef pair<int,pi> pii;
typedef pair<ll, ll> pl;
const int MAX = 1e5+3;
const int SIZE = 17;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
ll N,K,d[MAX],p[MAX][SIZE + 1],dis[MAX],dist,idx;
vector<pl> v[MAX];
bool vis[MAX];
void init(int cur) {
    for (auto i : v[cur]) {
        if (d[i.second] == -1) {
            d[i.second] = d[cur] + 1;
            p[i.second][0] = cur;
            dis[i.second]=dis[cur]+i.first;
            init(i.second);
        }
    }
}
ll lca(int a, int b) {
    if (d[a] < d[b])
        swap(a, b);
    ll diff = d[a] - d[b];
    int j = 0;
    while (diff) {
        if (diff % 2)
            a = p[a][j];
        diff /= 2;
        j++;
    }
    if (a == b)
        return a;
    for (int j = SIZE; j >= 0; j--) {
        if (p[a][j] != -1 && p[a][j] != p[b][j]) {
            a = p[a][j];
            b = p[b][j];
        }
    }
    a = p[a][0];
    return a;
}
ll solve(ll s, ll e){
    while(s<=e){
        ll mid=(s+e)/2;
        if(p[idx][mid]==-1){
            return solve(s,mid-1);
        }
        ll distance=dis[idx]-dis[p[idx][mid]];
        if(distance==dist)return p[idx][mid];
        else if(distance<dist){
            idx=p[idx][mid];
            dist-=distance;
            return solve(0,e);
        }
        else return solve(s,mid-1);
    }
    return -1;
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    cin>>N>>K;
    for(int i=0;i<N-1;i++){
        ll n1,n2,cost;
        cin>>n1>>n2>>cost;
        v[n1].push_back({cost,n2});
        v[n2].push_back({cost,n1});
    }
    memset(p, -1, sizeof(p));
    fill(d, d + MAX, -1);
    d[1]=0;
    init(1);
    for (int j = 0; j < SIZE; j++) {
        for (int i = 1; i <= N; i++) {
            if (p[i][j] != -1)
                p[i][j + 1] = p[p[i][j]][j];
        }
    }
    while(K--){
        int n1,n2;
        cin>>n1>>n2;
        int node=lca(n1,n2);
        ll dist1=dis[n1]-dis[node];
        ll dist2=dis[n2]-dis[node];
        if(dist1==dist2){
            cout<<node<<"\n";
            continue;
        }
        if((dist1+dist2)%2) {
            cout<<-1<<"\n";
            continue;
        }
        dist=(dist1+dist2)/2;
        dist1>dist2?idx=n1:idx=n2;
        cout<<solve(0,SIZE)<<"\n";
    }
}

플레티넘4 또는 플레티넘3 정도인 듯하다.

사실 지난번에 시도했다가 WA를 받은 문제여서 문제가 기억에 남아 새벽에 빨리 풀고 자려했는데,

알고리즘을 잘못 설계했었어서 생각보다 더 오래걸렸다.

 

sparse_table이 O(nlgn)이긴 하지만, 그만큼 건너뛰는 노드들이 있다는 사실을 명심하자.

반응형