PS/백준

[백준] 16236번 아기 상어 (python)

yo0oni 2024. 7. 15. 20:20

https://www.acmicpc.net/problem/16236


6개월 전에 풀었던 문젠데 한 번 더 풀었다.. 문제만 익숙했고 푸는 방식은 여전히 헷갈렸다.

 

헷갈렸던 점

  • 먹을 수 있는 물고기가 한 마리보다 많으면, 거리가 가장 가까운 물고기를 먹는다.
  • 거리가 가까운 물고기 많다면, 가장 위에 있는 물고기, 그러한 물고기도 많다면, 가장 왼쪽에 있는 물고기를 먹는다.

사실상 이 조건 때문에 골드 3인 거 같다. 나머지는 보통의 그래프 탐색 문제와 똑같다.

 

 

문제 풀이

문제에서 주어지는 기본 조건이 많기 때문에 미리 선언해야 한다. 상어의 크기, 상어가 먹은 물고기 개수, 상어의 위치, 총 소요 시간 등 미리 변수로 선언해 줬다.

 

그리고 처음 상어의 위치를 찾아줬다. 왜냐면 상어의 위치를 기점으로 물고기를 찾아야 하기 때문이다.

 

그 후에는 먹을 수 있는 물고기가 없을 때까지 반복문을 돌렸다. 이때 fish_list[0]은 3차원 리스트이기 때문에 상어의 x, y 좌표와 시간을 반환한다.

만약 상어 사이즈가 먹은 물고기 개수와 같아지면 상어의 크기를 +1 해주고, 먹은 개수는 다시 0으로 초기화해줘야 한다. 초기화를 안 해줬다가 틀렸다. 만약 사이즈가 2인데 두 마리를 먹었으면 3이 되고 먹은 개수는 0이 된다. 다시 세 마리를 먹어야 크기를 업그레이드할 수 있는 거다.

그리고 먹은 물고기의 위치는 0으로 바꿔준다. 이제 물고기가 없기 때문이다. 해당 물고기를 찾으러 간 시간도 반드시 총 소요시간에 더해준다.

 

find_fish 메서드는 다음과 같다.

단순한 bfs 로직이지만, 거리가 가까운 물고기부터 먹어야 하기 때문에 distance를 만들어 측정했다. 그 후 문제에서 요구하는 물고기 우선순위인 가장 가깝, 가장 위, 가장 왼쪽 순으로 정렬하였다.

이렇게 정렬하면 깔끔하게 물고기 우선순위를 뽑아서 반환할 수 있다.

 

Q. 정렬하면 시간초과 나지 않나요?

A. N이 20 이하입니다.

 


 

정답 코드

import sys
from collections import deque
input = sys.stdin.readline

dx = [1, 0, -1, 0]
dy = [0, -1, 0, 1]

def find_fish(shark_x, shark_y):
    global shark_size

    visited = [[False] * n for _ in range(n)]
    distance = [[0] * n for _ in range(n)]
    fish_can_eat = []

    dq = deque([(shark_x, shark_y)])

    while dq:
        x, y = dq.popleft()

        for i in range(4):
            nx = x + dx[i]
            ny = y + dy[i]

            if 0 <= nx < n and 0 <= ny < n and not visited[nx][ny]:
                if graph[nx][ny] <= shark_size:
                    visited[nx][ny] = True
                    distance[nx][ny] = distance[x][y] + 1
                    dq.append([nx, ny])

                    if graph[nx][ny] != 0 and graph[nx][ny] < shark_size:
                        fish_can_eat.append([nx, ny, distance[nx][ny]])

    fish_can_eat.sort(key = lambda x : (x[2], x[0], x[1]))
    return fish_can_eat


n = int(input())
graph = [list(map(int, input().split())) for _ in range(n)]
shark_x, shark_y = 0, 0
shark_size, shark_eat = 2, 0
total_time = 0
fish_list = []

for i in range(n):
    for j in range(n):
        if graph[i][j] == 9:
            shark_x, shark_y = i, j
            graph[i][j] = 0

while True:
    fish_list = find_fish(shark_x, shark_y)
    
    if len(fish_list) == 0:
        break

    shark_x, shark_y, time = fish_list[0]
    shark_eat += 1

    if shark_size == shark_eat:
        shark_eat = 0
        shark_size += 1

    graph[shark_x][shark_y] = 0
    total_time += time

print(total_time)