[BOJ] 14889번: 스타트와 링크 (조합론, 백트래킹)

728x90

 

(( 한 번 더 풀기 !!!!!!!!!!!!!!!!!!!! 무조건 )) 

 

문제 링크

https://www.acmicpc.net/problem/14889

 

백트래킹을 사용해서 풀 수 있는 문제라고 분류되어 있지만, 안써도 충분히 풀림

그치만 백트래킹 풀이와 조합 풀이 둘 다 해보기 !! (itertools 사용 못 할 수도 있으니)

 

 

TIL

  • 처음에 조합론을 생각함
  • 그치만 backtrack을 공부하고 있었기에 혹시 가능할까 싶어 풀어봄 ..
  • 결론은 조합론 + 백트래킹 or 조합론 정도가 적당할듯
  • 주어진 원소의 후보군을 다룰 땐, 원소 자체를 판별하는 것이 좋음 (격자 이런거 만들지 말고)
  • 반복되는 기능을 함수로 따로 뺄 때는 꼭 초기화 주의 !!!!

 

 

[ 처음 코드 ] - 이미 팀에 들어간 사람을 고려하지 않아서 중복 문제 발생 > (1,3), (1,2) 둘 다 가능하게 되버림

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

N = int(input())
array = [list(map(int, input().split())) for _ in range(N)]

min_ans = float("inf")
sum_all = sum(x for y in array for x in y)

dir = [(0, 1), (1, 0)]    # 아래, 오른쪽으로만 이동 (조합: 절반만 볼 것)
visited = [[False]*N for _ in range(N)]

def backtrack(cur, cy, cx, n):
    global min_ans
    if n == N//2:
        other = sum_all - cur
        diff = abs(cur - other)
        min_ans = min(min_ans, diff)
        return

    for dy, dx in dir:
        i, j = cy+dy, cx+dx
        if 0 <= i < N and 0 <= j < N and not visited[i][j] and i > j:
            visited[i][j] = True
            visited[j][i] = True
            backtrack(cur+array[i][j]+array[j][i], i, j, n+1)

            visited[i][j] = False
            visited[j][i] = False
            backtrack(cur, i, j, n)



backtrack(0, 0, 0, 0)

print(min_ans)
  • 좋지 않은 접근 방식
    • 주어진 원소의 후보군을 다룰 땐, 원소 자체를 판별하는 것이 좋음 (격자 이런거 만들지 말고)
  • 결과값이 계속 더 크게 나온 이유
    • 이미 팀에 들어간 사람을 고려하지 않아서 중복 문제 발생
    • (1,3)으로 이미 처리된 1과 3인데, (1,2)가 가능해져버림

 

[  첫 번째 문제 고친 두 번째 코드 ] - 이번엔 dir 방식으로 좌표로 계산하고자 해서 전체를 경우를 순회하지 못해 틀림 

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

N = int(input())
array = [list(map(int, input().split())) for _ in range(N)]

min_ans = float("inf")
sum_all = sum(x for y in array for x in y)
used = set()

dir = [(0, 1), (1, 0), (-1,0), (0,-1)]    # 아래, 오른쪽으로만 이동 (조합: 절반만 볼 것)
# set을 이용해 이동하기 때문에 아래랑 위만 봐버리면 중간에 보다가 멈춰버림 ㅠㅠ 
# 그냥 dir을 사용하는 방식 자체가 문제 ,, 0,0 기준으로 들어왔을 때 생각해보면 1,0까지만 조회하고 멈춰버림.
visited = [[False]*N for _ in range(N)]

def backtrack(cur, cy, cx, n):
    global min_ans
    if n == N//2:
        other = sum_all - cur
        diff = abs(cur - other)
        min_ans = min(min_ans, diff)
        return
    print("cur", cur)
    for dy, dx in dir:
        i, j = cy+dy, cx+dx
        if 0 <= i < N and 0 <= j < N and i not in used and j not in used and i > j:
            # visited[i][j] = True
            # visited[j][i] = True
            used.add(i)
            used.add(j)
            backtrack(cur+array[i][j]+array[j][i], i, j, n+1)

            # visited[i][j] = False
            # visited[j][i] = False
            used.remove(i)
            used.remove(j)
            # backtrack(cur, i, j, n)

backtrack(0, 0, 0, 0)

print(min_ans)
  • 위 첫 번째 코드의 문제점을 잡고자 set을 도입해 파악하고자 함
  • set을 도입해 파악하면 dir 방식으로 격자를 순회하는 것에 문제가 발생
    • (0,0) > (1,0) 에서 0과 1이 set에 추가됨
    • (1,0) > (1,1) 탈락 (이미 1이 set에 존재)
    • (1,0) > (2,0) 탈락 (이미 0이 set에 존재)
    • 이러면 순회가 끝나버림

 

[ 이걸 고치기 위해 dir이 아닌 for loop으로 모든 범위에 대해 순회 ]

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

N = int(input())
array = [list(map(int, input().split())) for _ in range(N)]

min_ans = float("inf")
sum_all = sum(array[i][j] for i in range(N) for j in range(N))
used = set()

def backtrack(cur, n):
    global min_ans
    if n == N // 2:
        other = sum_all - cur
        diff = abs(cur - other)
        min_ans = min(min_ans, diff)
        return

    for i in range(N):
        for j in range(i):
            if i not in used and j not in used:
                used.add(i)
                used.add(j)
                backtrack(cur + array[i][j] + array[j][i], n + 1)
                used.remove(i)
                used.remove(j)
                # backtrack(cur, n) >> 무한루프 발생 (이거 안해도 어차피 모든 경우의 수 가능 조합이니까)

backtrack(0, 0)
print(min_ans)
  • 애초에 sum_all - cur 방식이 문제 >> 현재 sum_all 정의가 잘못됐음
    • 팀 A 안의 쌍도 포함됨 ✅
    • 팀 B 안의 쌍도 포함됨 ✅
    • ❗ 그리고 A팀 사람이랑 B팀 사람이 함께 있는 (i, j) 쌍도 포함됨 ❗
    • 예를 들어 (1,2) > A , (0,3) > B 인 상황에서 sum_all 에 (1,3) 도 포함되어 버림

 

 

 

Code

(( 메모리 사용은 거의 비슷한데, 시간은 백트래킹이 훨씬 빠름 ))

 

[ 조합론 코드 ]

from itertools import combinations
import sys
input = sys.stdin.readline

N = int(input())
array = [list(map(int, input().split())) for _ in range(N)]


def sum_power(members):
    visited = [[False]*N for _ in range(N)]  # >> 처음에 이걸 전역변수로 둬서 초기화되지 않아서 틀림
    ret = 0
    for m in members:
        for i in range(N):
            if i in members and not visited[i][m] and not visited[m][i]:
                ret += array[i][m] + array[m][i]
                visited[i][m] = True
                visited[m][i] = True
    return ret

candidates = combinations([i for i in range(N)], N//2)  # 반만 알면, 반은 알아서 결정
min_diff = float("inf")
all_set = set(i for i in range(N))

for c in candidates:
    team1 = set(c)
    team2 = all_set-team1
    min_diff = min(abs(sum_power(team1)-sum_power(team2)), min_diff)
    if min_diff == 0:
        print(0)
        exit()

print(min_diff)
  • 반복되는 기능을 함수로 따로 뺄 때는 꼭 초기화 주의 !!!!
    • 예를 들어, 매개변수를 받아서 값을 계산해서 돌려주는 함수와 같은 경우 꼭 초기화 확인해야 함

[ 백트래킹 코드 ] 

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

N = int(input())
S = [list(map(int, input().split())) for _ in range(N)]

min_diff = float('inf')
selected = [False] * N  # 어떤 사람이 스타트 팀에 들어갔는지 여부

def calculate_diff():
    start_team = []
    link_team = []
    for i in range(N):
        if selected[i]:
            start_team.append(i)
        else:
            link_team.append(i)

    start_score = 0
    link_score = 0

    for i in range(N // 2):
        for j in range(i + 1, N // 2):
            s1, s2 = start_team[i], start_team[j]
            l1, l2 = link_team[i], link_team[j]
            start_score += S[s1][s2] + S[s2][s1]
            link_score += S[l1][l2] + S[l2][l1]

    return abs(start_score - link_score)

def backtrack(idx, count):
    global min_diff

    # 종료 조건: N/2명을 뽑았을 때
    if count == N // 2:
        diff = calculate_diff()
        min_diff = min(min_diff, diff)
        return

    for i in range(idx, N):
        selected[i] = True
        backtrack(i + 1, count + 1)
        selected[i] = False

# 0번은 무조건 스타트 팀에 넣음 → 대칭 제거
selected[0] = True
backtrack(1, 1)
print(min_diff)
  • 탐색 절반으로 줄이기 (대칭 제거)
    • 스타트 팀으로 [0, 1, 2, ..., N//2 - 1]을 뽑는 것과 링크 팀으로 [0, 1, 2, ..., N//2 - 1]을 뽑는 것은 완전히 같은 경우
    • 즉, 첫 번째 인덱스(0번)는 무조건 스타트 팀에 넣고 시작하면 전체 경우의 수를 절반으로 줄일 수 있음!

예) N = 4 (총 사람 4명: 0, 1, 2, 3)

  • 이들을 두 팀으로 나누는 경우
    • 경우 1: 스타트 팀 [0, 1], 링크 팀 [2, 3]
    • 경우 2: 스타트 팀 [2, 3], 링크 팀 [0, 1]
  • 이 두 개는 똑같은 경우임
    • 왜? 팀 이름만 바뀐 거고, 능력치 차이는 똑같음
    • 즉, 순열의 관점에서는 다르지만, 이 문제에서는 “누가 스타트 팀이고 누가 링크 팀인지”는 중요하지 않아.

 

“대칭 가지치기 (Symmetry Pruning)” 또는 “중복 상태 제거 (Avoiding Equivalent States)”

이유 1: 시간 줄이기

  • 예를 들어 N=20이면, 조합은 C(20, 10) → 184,756가지
  • 근데 사실상 절반은 “이름만 바꾼 같은 팀”
  • 그래서 절반인 C(19, 9)만 보면 됨 → 약 92,378개
  • 2배 빨라짐

이유 2: 같은 해답 여러 번 구하지 않게 하기

  • 문제는 "능력치 차이" 같은 값만 중요할 때가 많음
  • 팀 이름, 순서는 중요하지 않음 → 중복 제거 필요

언제 주로 사용하는지

유형 예시
팀 나누기 스타트팀 vs 링크팀 (이 문제처럼)
부분 집합 만들기 A와 전체 - A는 사실 같은 경우
그래프에서 노드 고르기 특정 노드를 기준으로 시작하면 중복 제거
순열 만들기 같은 숫자가 여러 개 있을 때 중복 방지
N-Queen 문제 대칭 판을 미리 배제함