
DQN(Deep Q-Network): 딥러닝과 강화학습의 결합
(수정: 2026년 1월 1일 오전 08:33)
DQN(Deep Q-Network): 딥러닝과 강화학습의 결합
2013년 DeepMind가 발표한 DQN(Deep Q-Network) 은 딥러닝과 강화학습을 결합하여 아타리 게임에서 인간 수준의 성능을 달성했습니다. 이 글에서는 DQN의 핵심 개념과 구현 방법을 살펴봅니다.
1. Q-Learning 복습
Q-Value란?
Q-Value (Action-Value) 는 상태 s에서 행동 a를 취했을 때 얻을 수 있는 미래 보상의 기대값입니다.
γ (gamma): 할인율 (보통 0.99)
Q-Learning 업데이트
테이블 기반의 한계
# 테이블 기반 Q-Learning Q_table = {} # {(state, action): q_value} # 문제: 상태 공간이 클 때 # 아타리 게임 화면 = 210 x 160 x 3 픽셀 # 가능한 상태 수 = 256^(210*160*3) → 저장 불가능
2. DQN의 핵심 아이디어
신경망으로 Q-Value 근사
입력: 상태 (State) → [신경망] → 출력: 각 행동의 Q-Value
import torch.nn as nn class DQN(nn.Module): def __init__(self, state_dim, action_dim): super().__init__() self.network = nn.Sequential( nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, action_dim) ) def forward(self, x): return self.network(x) # 사용 예시 model = DQN(state_dim=4, action_dim=2) state = torch.randn(1, 4) q_values = model(state) # [Q(s,a1), Q(s,a2)]
3. 학습 안정화 기법
단순히 신경망으로 Q-Value를 근사하면 학습이 불안정합니다. DQN은 두 가지 기법으로 이를 해결했습니다.
1) Experience Replay
에이전트가 겪은 경험 (s, a, r, s', done)을 버퍼에 저장하고, 무작위로 샘플링하여 학습합니다.
사용 이유:
- 연속된 데이터의 상관관계를 깨뜨림
- 같은 경험을 여러 번 재사용 가능
from collections import deque import random class ReplayBuffer: def __init__(self, capacity=10000): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): batch = random.sample(self.buffer, batch_size) states, actions, rewards, next_states, dones = zip(*batch) return states, actions, rewards, next_states, dones def __len__(self): return len(self.buffer)
2) Target Network
Q-value 타겟을 계산할 때 별도의 네트워크를 사용합니다.
사용 이유:
- 타겟이 계속 변하면 학습이 불안정
- 타겟 네트워크는 주기적으로만 업데이트
# 메인 네트워크와 타겟 네트워크 policy_net = DQN(state_dim=4, action_dim=2) target_net = DQN(state_dim=4, action_dim=2) # 초기화: 같은 가중치로 시작 target_net.load_state_dict(policy_net.state_dict()) target_net.eval() # 타겟 네트워크는 학습하지 않음 # 주기적으로 타겟 네트워크 업데이트 def update_target_network(): target_net.load_state_dict(policy_net.state_dict())
4. DQN 알고리즘
1. 리플레이 버퍼, 메인 네트워크(Q), 타겟 네트워크(Q') 초기화 반복: 2. 현재 상태 s에서 ε-greedy로 행동 선택 - 확률 ε: 무작위 행동 - 확률 1-ε: argmax Q(s, a) 3. 행동 실행 → 보상 r, 다음 상태 s' 관찰 4. (s, a, r, s', done)을 버퍼에 저장 5. 버퍼에서 미니배치 샘플링 6. 손실 계산: y = r + γ * max Q'(s', a') (done이면 y = r) Loss = (Q(s, a) - y)² 7. 메인 네트워크 업데이트 8. 주기적으로 타겟 네트워크 동기화
5. PyTorch 구현
전체 코드
import gymnasium as gym import numpy as np import torch import torch.nn as nn import torch.optim as optim from collections import deque import random # 하이퍼파라미터 BATCH_SIZE = 64 GAMMA = 0.99 EPSILON_START = 1.0 EPSILON_END = 0.01 EPSILON_DECAY = 0.995 TARGET_UPDATE = 10 MEMORY_SIZE = 10000 LEARNING_RATE = 0.001 EPISODES = 500 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class DQN(nn.Module): def __init__(self, state_dim, action_dim): super().__init__() self.network = nn.Sequential( nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, action_dim) ) def forward(self, x): return self.network(x) class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): batch = random.sample(self.buffer, batch_size) states, actions, rewards, next_states, dones = zip(*batch) return ( torch.FloatTensor(np.array(states)).to(device), torch.LongTensor(actions).to(device), torch.FloatTensor(rewards).to(device), torch.FloatTensor(np.array(next_states)).to(device), torch.FloatTensor(dones).to(device) ) def __len__(self): return len(self.buffer) class DQNAgent: def __init__(self, state_dim, action_dim): self.action_dim = action_dim self.epsilon = EPSILON_START self.policy_net = DQN(state_dim, action_dim).to(device) self.target_net = DQN(state_dim, action_dim).to(device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) self.memory = ReplayBuffer(MEMORY_SIZE) def select_action(self, state): if random.random() < self.epsilon: return random.randrange(self.action_dim) else: with torch.no_grad(): state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) q_values = self.policy_net(state_tensor) return q_values.argmax().item() def learn(self): if len(self.memory) < BATCH_SIZE: return None states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE) # 현재 Q값 current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)) # 타겟 Q값 with torch.no_grad(): next_q = self.target_net(next_states).max(1)[0] target_q = rewards + GAMMA * next_q * (1 - dones) # 손실 계산 loss = nn.MSELoss()(current_q.squeeze(), target_q) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0) self.optimizer.step() return loss.item() def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def update_epsilon(self): self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY) def train(): env = gym.make('CartPole-v1') state_dim = env.observation_space.shape[0] action_dim = env.action_space.n agent = DQNAgent(state_dim, action_dim) episode_rewards = [] for episode in range(EPISODES): state, _ = env.reset() total_reward = 0 while True: action = agent.select_action(state) next_state, reward, terminated, truncated, _ = env.step(action) done = terminated or truncated agent.memory.push(state, action, reward, next_state, float(done)) agent.learn() total_reward += reward state = next_state if done: break agent.update_epsilon() if episode % TARGET_UPDATE == 0: agent.update_target_network() episode_rewards.append(total_reward) if episode % 10 == 0: avg_reward = np.mean(episode_rewards[-100:]) print(f"Episode {episode}, Reward: {total_reward:.0f}, " f"Avg: {avg_reward:.1f}, Epsilon: {agent.epsilon:.3f}") env.close() return episode_rewards, agent if __name__ == "__main__": rewards, agent = train()
실행 방법
pip install gymnasium torch numpy python dqn_cartpole.py
6. 아타리 게임용 CNN 구조
이미지 입력을 처리하는 DQN:
class AtariDQN(nn.Module): def __init__(self, action_dim): super().__init__() # 입력: 84x84x4 (4개 프레임 스택) self.conv = nn.Sequential( nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU() ) self.fc = nn.Sequential( nn.Linear(64 * 7 * 7, 512), nn.ReLU(), nn.Linear(512, action_dim) ) def forward(self, x): x = x / 255.0 # 정규화 x = self.conv(x) x = x.view(x.size(0), -1) return self.fc(x)
7. 트러블슈팅
| 문제 | 원인 | 해결 |
|---|---|---|
| 학습이 안 됨 | 학습률 부적절 | 1e-4 ~ 1e-3 범위 조정 |
| 성능 하락 | ε 감소가 너무 빠름 | EPSILON_DECAY 낮추기 |
| Q값 폭발 | 타겟 업데이트 빈번 | TARGET_UPDATE 늘리기 |
| 메모리 부족 | 버퍼가 너무 큼 | MEMORY_SIZE 줄이기 |
8. 핵심 정리
| 구성 요소 | 역할 |
|---|---|
| 신경망 | Q(s,a)를 근사 |
| Experience Replay | 데이터 상관관계 제거, 재사용 |
| Target Network | 학습 타겟 안정화 |
| ε-greedy | 탐험과 활용 균형 |
Quiz
DQN에서 Experience Replay를 사용하는 주된 이유는?