<

PS/알고리즘

[Algorithm/C++] 1,2차원에서의 구간 합 구하기 (+백준 연습문제, kakao 기출문제)

leedongbin 2023. 8. 4. 17:23

이번 글에서는, 아래의 네 가지 조건으로 만들어지는 다양한 상황에서 알맞은 방식으로 구간 합을 구하는 테크닉을 소개하려고 합니다.

  • 1차원 공간 vs 2차원 공간
  • 점 업데이트 vs 구간 업데이트
  • 점 쿼리 vs 구간 쿼리 (특정 위치에 대한 개수만을 묻는가? vs 어떤 구간의 총합을 묻는가?)
  • 오프라인 쿼리 vs 온라인 쿼리 (질문을 업데이트가 모두 끝나고 하는가? vs 업데이트 도중에 질문을 하는가?)

 

문제는 난이도 순이 아니지만, 순서대로 읽어보시면 더 쉽게 이해하실 수 있습니다.

3, 4, 8번 문제는 세그먼트 트리에 대한 사전 지식을 필요로합니다.
세그먼트 트리에 관한 좋은 글이 이미 많기 때문에, 관련 내용은 링크로 대체하겠습니다.
https://blog.naver.com/kks227/220791986409 (세그먼트 트리)
https://blog.naver.com/kks227/220824350353 (세그먼트 트리 with lazy propagation)

1. 구간 합 구하기 4 (백준 11659)

  • 1차원 공간
  • 점 업데이트 (초기 값만 세팅하는 경우도 점 업데이트로 볼 수 있습니다.)
  • 오프라인 쿼리

가장 기본적인 형태의 문제입니다.

구간 $[1, x]$의 총 합 $(= sum[x])$을 누적 합으로 미리 구해준 뒤, 구간 $[i, j]$의 합은 $sum[j] - sum[i-1]$로 쉽게 구할 수 있습니다.

시간 복잡도는 누적 합을 구하는 데 $O(N)$, $M$개의 쿼리를 처리하는 데 각각 $O(1)$이므로

$O(N+M)$입니다.

코드

#include<bits/stdc++.h>
using namespace std;

int main(){
    int n,m;scanf("%d%d",&n,&m);
    vector<int> a(n+1),sum(n+1);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        sum[i]=sum[i-1]+a[i];
    }
    while(m--){
        int i,j;scanf("%d%d",&i,&j);
        printf("%d\n",sum[j]-sum[i-1]);
    }
    return 0;
}

2. 시간 구간 다중 업데이트 다중 합 (백준 25827)

  • 1차원 공간
  • 구간 업데이트
  • 오프라인 쿼리

1번 문제에서 구간 업데이트로 확장되었습니다.

오프라인 쿼리인 상황에서는, imos기법(누적합 트릭)을 사용하면 구간 업데이트를 $O(1)$에 할 수 있습니다!

pc방에 머무른 손님들의 정보를 그림으로 표현했다고 가정해 보겠습니다.

이처럼 다양한 시간 구간이 주어졌을 때, 각각의 구간$[s, e]$에 대해서 for문을 돌리는 게 아니라,

왼쪽부터 쭉 살펴보면 선분이 시작하는 부분 s에서 +1(선분 추가), 끝나는 부분 e+1에서 -1(선분 삭제)라는 정보만 알면 됩니다. (구간이 [s, e]인 경우 e+1에서, [s, e)인 경우 e에서 삭제하면 됩니다.)

따라서 위 그림처럼, 한 선분을 업데이트할 때 화살표 두 개(입장기록, 퇴장기록)만 업데이트해주면 됩니다.

이 정보를 바탕으로 왼쪽부터 누적합을 구해주면, 그 위치(시각)에서의 선분의 개수(손님의 수)가 됩니다!

이제 각 위치에서의 선분의 개수를 알고 있으므로, 1번 문제와 똑같이 누적 합을 구해주면 구간 쿼리를 수행할 수 있습니다.

이때 $sum[x]$는, [pc방 오픈 시점 ~ 시각 x]까지 손님들이 사용한 총전력량 정도에 비유할 수 있겠네요.

시간 복잡도는 누적 합을 두 번 구하는 데 각각 $O(maxTime)$, $M$개의 쿼리를 처리하는 데 각각 $O(1)$이므로

$O(maxTime+M)$입니다.

코드

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int sz=24*60*60;
ll a[sz+2],sum[sz+2],b;

void init(){
    if(b)return; b=1; //2번 쿼리가 수행되는 처음 한 번만 초기화.
    for(int t=1;t<=sz;t++){
        a[t]+=a[t-1];
        sum[t]=sum[t-1]+a[t];
    }
}

void add(int s,int e){
    a[s]++,a[e]--;
}

int main(){
    int n;scanf("%d",&n);
    while(n--){
        int q,h1,m1,s1,h2,m2,s2;
        scanf("%d %d:%d:%d %d:%d:%d",&q,&h1,&m1,&s1,&h2,&m2,&s2);
        int s=h1*3600+m1*60+s1,e=h2*3600+m2*60+s2;
        s++,e++;//1부터 시작하도록 1칸 이동.
        if(q==1){
            add(s,e);
        }
        else{
            init();
            printf("%lld\n",sum[e-1]-sum[s-1]);
        }
    }
    return 0;
}

3. 구간 합 구하기 (백준 2042)

  • 1차원 공간
  • 점 업데이트
  • 온라인 쿼리

1번 문제에서 바뀐 건 마지막 조건뿐인데, 세그먼트 트리의 기본 문제가 되었습니다.

시간 복잡도는 초기 $N$개의 값 세팅에 각각 $O(logN)$, $M$번 업데이트 하는 데 각각 $O(logN)$, $K$개의 쿼리를 처리하는 데 각각 $O(logN)$이므로

$O((N+M+K)logN)$입니다. (초기 $N$개의 값 세팅은 총 $O(N)$에 구하는 방법도 있습니다.)

코드

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int sz=1<<20;
ll arr[sz<<1];

void update(int i,ll val){
    i+=sz,arr[i]=val;
    while(i>1)i>>=1,arr[i]=arr[i<<1]+arr[i<<1|1];
}

ll query(int s,int e,int node,int ns,int ne){
    if(e<ns||ne<s)return 0;
    if(s<=ns&&ne<=e)return arr[node];
    int mid=(ns+ne)>>1;
    return query(s,e,node<<1,ns,mid)+query(s,e,node<<1|1,mid+1,ne);
}ll query(int s,int e){return query(s,e,1,0,sz-1);}

int main(){
    int n,m,k;
    scanf("%d %d %d",&n,&m,&k);
    for(int i=1;i<=n;i++){
        ll x;scanf("%lld",&x);
        update(i,x);
    }
    int q=m+k;
    while(q--){
        int a;scanf("%d",&a);
        if(a==1){
            int b;ll c;scanf("%d%lld",&b,&c);
            update(b,c);
        }
        else{
            int b,c;scanf("%d%d",&b,&c);
            printf("%lld\n",query(b,c));
        }
    }
    return 0;
}

4. 구간 합 구하기 2 (백준 10999)

  • 1차원 공간
  • 구간 업데이트
  • 온라인 쿼리

3번 문제에서 구간 업데이트로 확장되었고, segment tree with lazy propagation의 기본 문제가 되었습니다.

시간 복잡도는 초기 $N$개의 값 세팅에 각각 $O(logN)$, $M$번 업데이트 하는 데 각각 $O(logN)$, $K$개의 쿼리를 처리하는 데 각각 $O(logN)$이므로

$O((N+M+K)logN)$입니다. (마찬가지로 초기 $N$개의 값 세팅은 총 $O(N)$에 구하는 방법도 있습니다.)

코드

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int sz=1<<20;
ll arr[sz<<1],lazy[sz<<1];

void propagate(int node,int ns,int ne){
    if(!lazy[node])return;
    if(node<sz){
        lazy[node<<1]+=lazy[node];
        lazy[node<<1|1]+=lazy[node];
    }
    arr[node]+=lazy[node]*(ll)(ne-ns+1);
    lazy[node]=0;
}

void update(int s,int e,ll k,int node,int ns,int ne){
    propagate(node,ns,ne);
    if(e<ns||ne<s)return;
    if (s<=ns&&ne<=e){
        lazy[node]+=k;
        propagate(node,ns,ne);return;
    }
    int mid=(ns+ne)>>1;
    update(s,e,k,node<<1,ns,mid),update(s,e,k,node<<1|1,mid+1,ne);
    arr[node]=arr[node<<1]+arr[node<<1|1];
}void update(int s,int e,ll k){update(s,e,k,1,0,sz-1);}

ll query(int s,int e,int node,int ns,int ne){
    propagate(node,ns,ne);
    if(e<ns||ne<s)return 0;
    if(s<=ns&&ne<=e)return arr[node];
    int mid=(ns+ne)>>1;
    return query(s,e,node<<1,ns,mid)+query(s,e,node<<1|1,mid+1,ne);
}ll query(int s,int e){return query(s,e,1,0,sz-1);}

int main(){
    int n,m,k;
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=n;i++){
        ll x;scanf("%lld",&x);
        update(i,i,x);
    }

    int q=m+k;
    while(q--){
        int a;scanf("%d",&a);
        if(a==1){
            int b,c;ll d;scanf("%d%d%lld",&b,&c,&d);
            update(b,c,d);
        }
        else{
            int b,c;scanf("%d%d",&b,&c);
            printf("%lld\n",query(b,c));
        }
    }
    return 0;
}

5. 구간 합 구하기 5 (백준 11660)

  • 2차원 공간
  • 점 업데이트 (초기 값만 세팅하는 경우도 점 업데이트로 볼 수 있습니다.)
  • 오프라인 쿼리

1번 문제에서 2차원 공간으로 확장되었습니다.

구간 $[(1,1), (x, y)]$의 총 합 $(= sum[x][y])$을 누적 합으로 미리 구해준 뒤, 구간 $[(x1, y1), (x2, y2)]$의 합은 sum[x2][y2] - sum[x1-1][y2] - sum[x2][y1-1] + sum[x1-1][y1-1]로 구할 수 있습니다. (누적 합을 미리 구하는 과정도 이와 비슷합니다.)

시간복잡도는 초기 $N^{2}$개의 값 세팅에 각각 $O(1)$, $M$개의 쿼리를 처리하는 데 각각 $O(1)$이므로

$O(N^{2}+M)$입니다.

코드

#include<bits/stdc++.h>
using namespace std;
int a[1025][1025],sum[1025][1025];

int main(){
    int n,m;scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){
        scanf("%d",&a[i][j]);
        sum[i][j]=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1]+a[i][j];
    }
    while(m--){
        int x1,y1,x2,y2;scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
        printf("%d\n",sum[x2][y2]-sum[x1-1][y2]-sum[x2][y1-1]+sum[x1-1][y1-1]);
    }
    return 0;
}

6. 파괴되지 않은 건물 (2022 KAKAO BLIND RECRUITMENT)

  • 2차원 공간
  • 구간 업데이트
  • 점 쿼리
  • 오프라인 쿼리

2번 문제에서 2차원 공간으로 확장된 경우입니다.

이 글에서 유일하게 점 쿼리인 문제이고, 특정 건물(점)의 파괴 여부를 묻는 상황입니다.

1차원 공간에서는 한 선분을 업데이트할 때 화살표 두 개를 업데이트했지만, 2차원 공간에서는 그림과 같이 네 개의 위치를 업데이트해주어야 합니다. 나머지는 5번 문제와 똑같은 방식으로 누적합을 구해주면 됩니다. 점 쿼리이므로 $sum$ 배열은 구할 필요가 없습니다.

시간복잡도는 초기 $N*M$개의 값 세팅에 각각 $O(1)$, $Q$개의 쿼리를 처리하는 데 각각 $O(1)$, 누적 합을 구하는 데 $O(NM)$이므로

$O(NM+len(skill))$입니다.

코드

#include<bits/stdc++.h>
using namespace std;
int a[1002][1002],sum[1002][1002];
int answer(int n,int m){
    int ans=0;
    for(int i=1;i<=n;i++)for(int j=1;j<=m;j++){
        a[i][j]+=a[i][j-1]+a[i-1][j]-a[i-1][j-1];
        if(a[i][j]>0)ans++;
    }
    return ans;
}

void add(int i1,int j1,int i2,int j2,int k){
    a[i1][j1]+=k,a[i2+1][j1]-=k,a[i1][j2+1]-=k,a[i2+1][j2+1]+=k;
}

int solution(vector<vector<int>> board, vector<vector<int>> skill) {
    int n=board.size(),m=board[0].size();
    for(int r=0;r<n;r++)for(int c=0;c<m;c++)
        add(r+1,c+1,r+1,c+1,board[r][c]);
    for(auto i:skill){
        int type=i[0],r1=i[1]+1,c1=i[2]+1,r2=i[3]+1,c2=i[4]+1,degree=i[5];
        if(type==1)degree*=-1;
        add(r1,c1,r2,c2,degree);
    }
    return answer(n,m);
}

7. 2차원 배열 다중 업데이트 다중 합 (백준 25978)

  • 2차원 공간
  • 구간 업데이트
  • 오프라인 쿼리

6번 문제에서 구간 쿼리로 확장되었습니다.

6번 문제와 같이 배열 a를 빠르게 구해주고, 5번 문제와 똑같이 배열 a를 가지고 한 번 더 누적 합$(sum)$을 구해주면 됩니다.

시간복잡도는 초기 $N^{2}$개의 값 세팅에 각각 $O(1)$, $M$개의 쿼리를 처리하는 데 각각 $O(1)$, 누적 합을 구하는 데 $O(N^{2})$이므로

$O(N^{2}+M)$입니다.

코드

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll a[1002][1002],sum[1002][1002],b;

void init(int n){
    if(b)return; b=1;
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){
        a[i][j]+=a[i][j-1]+a[i-1][j]-a[i-1][j-1];
        sum[i][j]=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1]+a[i][j];
    }
}

void add(int i1,int j1,int i2,int j2,int k){
    a[i1][j1]+=k,a[i2+1][j1]-=k,a[i1][j2+1]-=k,a[i2+1][j2+1]+=k;
}

int main(){
    int n,m;scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){
        int k;scanf("%d",&k);
        add(i,j,i,j,k);
    }

    while(m--){
        int q;scanf("%d",&q);
        if(q==1){
            int i1,i2,j1,j2,k;scanf("%d%d%d%d%d",&i1,&j1,&i2,&j2,&k);
            i1++,i2++,j1++,j2++;
            add(i1,j1,i2,j2,k);
        }
        else{
            init(n);
            int i1,i2,j1,j2;scanf("%d%d%d%d",&i1,&j1,&i2,&j2);
            i1++,i2++,j1++,j2++;
            printf("%lld\n",sum[i2][j2]-sum[i1-1][j2]-sum[i2][j1-1]+sum[i1-1][j1-1]);
        }
    }
    return 0;
}

8. 구간 합 구하기 3 (백준 11658)

  • 2차원 공간
  • 점 업데이트
  • 온라인 쿼리

3번 문제에서 2차원 공간으로 확장되었습니다.

세그먼트 트리를 2차원으로 구현하면 되는데, y의 범위를 미리 고정해 놓고 $log(N)$개의 y 구간들에 대해 1차원 세그먼트 트리처럼 업데이트해주면 됩니다. 자세한 설명은 코드로 대체합니다.

시간 복잡도는 초기 $N^{2}$개의 값 세팅에 각각 $O(log^{2}N)$, $M$개의 쿼리를 처리하는 데 각각 $O(log^{2}N)$이므로

$O((N^{2}+M)log^{2}N)$입니다.

코드 (2차원 세그먼트 트리)

#include<bits/stdc++.h>
using namespace std;
const int sz=1<<10;

int arr[sz<<1][sz<<1];

void update(int x,int y,int val){
    x+=sz,y+=sz;
    int init_x=x;
    arr[x][y]=val;
    while(x>1)x>>=1,arr[x][y]=arr[x<<1][y]+arr[x<<1|1][y];
    while(y>1){
        x=init_x;
        y>>=1,arr[x][y]=arr[x][y<<1]+arr[x][y<<1|1];
        while(x>1)x>>=1,arr[x][y]=arr[x<<1][y]+arr[x<<1|1][y];
    }
}

int y_fixed(int node_y,int xs,int xe,int node,int ns,int ne){
    if(xe<ns||ne<xs)return 0;
    if(xs<=ns&&ne<=xe)return arr[node][node_y];
    int mid=(ns+ne)>>1;
    return y_fixed(node_y,xs,xe,node<<1,ns,mid)+y_fixed(node_y,xs,xe,node<<1|1,mid+1,ne);
}int y_fixed(int node_y,int xs,int xe){return y_fixed(node_y,xs,xe,1,0,sz-1);}

int query(int xs,int ys,int xe,int ye,int node,int ns,int ne){
    if(ye<ns||ne<ys)return 0;
    if(ys<=ns&&ne<=ye)return y_fixed(node,xs,xe);
    int mid=(ns+ne)>>1;
    return query(xs,ys,xe,ye,node<<1,ns,mid)+query(xs,ys,xe,ye,node<<1|1,mid+1,ne);
}int query(int x1,int y1,int x2,int y2){return query(x1,y1,x2,y2,1,0,sz-1);}

int main(){
    int n,m;scanf("%d%d",&n,&m);
    for(int i=0;i<n;i++)for(int j=0;j<n;j++){
        int x;scanf("%d",&x);
        update(i,j,x);
    }
    while(m--){
        int w;scanf("%d",&w);
        if(!w){
            int x,y,c;scanf("%d%d%d",&x,&y,&c);
            x--,y--;
            update(x,y,c);
        }
        else{
            int x1,y1,x2,y2;scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
            x1--,y1--,x2--,y2--;
            printf("%d\n",query(x1,y1,x2,y2));
        }
    }
    return 0;
}

 

시간복잡도가 꽤 빡빡한데, 펜윅 트리(=Binary Indexed Tree, BIT)를 알고 있다면 5, 7번 문제에서 했던 것처럼 누적 합을 이용하여 약 1/3의 속도로 문제를 해결할 수 있습니다.

코드 (2차원 펜윅 트리)

#include<bits/stdc++.h>
using namespace std;
const int sz=1<<10;
int a[sz+1][sz+1],BIT[sz+1][sz+1];

int sum(int x,int y){
    int ret=0;
    for(int i=x;i;i-=i&-i)
        for(int j=y;j;j-=j&-j)
            ret+=BIT[i][j];
    return ret;
}

void update(int x,int y,int add){
    for(int i=x;i<=sz;i+=i&-i)
        for(int j=y;j<=sz;j+=j&-j)
            BIT[i][j]+=add;
}

int main(){
    int n,m;scanf("%d%d",&n,&m);
    for(int x=1;x<=n;x++)for(int y=1;y<=n;y++){
        scanf("%d",&a[x][y]);
        update(x,y,a[x][y]);
    }

    while(m--){
        int w;scanf("%d",&w);
        if(!w){
            int x,y,c;scanf("%d %d %d",&x,&y,&c);
            update(x,y,c-a[x][y]);
            a[x][y]=c;
        }
        else{
            int x1,y1,x2,y2;scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
            printf("%d\n",sum(x2,y2)-sum(x2,y1-1)-sum(x1-1,y2)+sum(x1-1,y1-1));
        }
    }
}