랜덤포레스트를 생존분석에 확장한것이 Random Survival Forest 였다면, 딥러닝 모델을 생존분석에 적용한 모델은 DeepSurv 이다. Katzman et al. 2018, BMC Medical Research Methodology (https://link.springer.com/article/10.1186/s12874-018-0482-1) 에서 소개된 모델로, Cox partial likelihood를 신경망의 손실 함수로 직접 사용 하는 것을 제시하여 생존분석 딥러닝의 시작점이 되는 모형이다.
1. DeepSurv의 핵심 — Cox β'X 를 신경망으로 대체

DeepSurv는 한마디로 Cox 비례위험 모형의 β'X 를 신경망으로 대체 한 것이다.
| Cox 회귀 | DeepSurv | |
| h(t|X) = h₀(t) · exp(β₁X₁ + β₂X₂ + ... + βₚXₚ) | h(t|X) = h₀(t) · exp(f(X)) | |
| 위험 점수 | β'X (선형) | f(X) (MLP의 비선형) |
| 비선형 학습 | 불가 (직접 다항·spline 추가) | 자동 학습 |
| 상호작용 | 직접 추가 | 자동 학습 |
"Cox의 통계적 프레임워크 + 신경망의 표현력" 의 결합이라고 보면 된다.
2. 손실 함수 — Cox Partial Likelihood 그대로
DeepSurv의 결정적 특징 — 손실 함수는 Cox 부분우도 (partial likelihood) 를 그대로 사용 한다.
L = − Σᵢ [ ηᵢ − log Σⱼ∈R(tᵢ) exp(ηⱼ) ]
여기서 :
- ηᵢ = MLP의 출력 (환자 i의 위험 점수)
- R(tᵢ) = 시점 tᵢ 의 risk set (그때까지 살아있는 환자들)
이 손실 함수의 의미 :
- 사건이 발생한 환자의 위험 점수가 높아지도록 학습
- 검열(censoring) 자동 처리 — 사건 발생자만 우도에 포함
- Cox와 동일한 통계적 프레임워크 — hazard ratio 해석 도 유지
3. 학습 흐름

- 환자 데이터 입력 — 시간-고정 변수 X + 생존 정보 (T, δ)
- MLP 통과 — η = f(X) 로 위험 점수 출력
- Cox 부분우도 손실 — 사건 발생자의 η가 높아지도록 손실 계산
- 가중치 업데이트 — Backpropagation + Adam 옵티마이저로 가중치 조정
이 과정을 손실이 줄어들 때까지 반복한다. 이전 #7편의 MLP 학습 4단계 와 같은 구조에 — 손실 함수만 Cox 부분우도로 바꾼 것이다.
4. 파이썬 코드 — PyCox 라이브러리
pycox 로 DeepSurv의 구현이 가능하며, pycox는 PyTorch 위에서 작동하는 생존분석 전용 wrapper 라이브러리이다.
pip install pycox
4-1. 데이터 준비
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from pycox.models import CoxPH
import torchtuples as tt
# 데이터 로드
df = pd.read_csv('survival_data.csv')
# 변수 (X) + 생존 정보 (T, δ)
features = ['age', 'tumor_size', 'T_stage', 'N_stage']
X = df[features].values.astype('float32')
T = df['time'].values.astype('float32') # 추적 시간
E = df['event'].values.astype('float32') # 사건 (1) / 검열 (0)
# 분할
X_tr, X_te, T_tr, T_te, E_tr, E_te = train_test_split(
X, T, E, test_size=0.2, random_state=42
)
# 정규화 (딥러닝에선 필수)
scaler = StandardScaler()
X_tr = scaler.fit_transform(X_tr).astype('float32')
X_te = scaler.transform(X_te).astype('float32')
# pycox 형식으로 변환
y_tr = (T_tr, E_tr)
y_te = (T_te, E_te)
4-2. 모델 정의
import torch.nn as nn
class DeepSurvNet(nn.Module):
def __init__(self, in_features, hidden_dim=64):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1) # 출력: 위험 점수 1개
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x)
return x
net = DeepSurvNet(in_features=X_tr.shape[1])
핵심 — 출력 차원이 1 (위험 점수 η 하나만). 일반 분류는 클래스 개수만큼이지만 DeepSurv는 항상 1 이다.
4-3. 학습
# pycox의 CoxPH 모델로 감싸기
model = CoxPH(net, optimizer=tt.optim.Adam(0.001))
# 학습
batch_size = 256
epochs = 100
log = model.fit(
X_tr, y_tr,
batch_size=batch_size,
epochs=epochs,
val_data=(X_te, y_te),
verbose=True
)
# Baseline hazard 계산 (생존 곡선 그리려면 필수)
_ = model.compute_baseline_hazards()
CoxPH 가 내부적으로 Cox 부분우도 손실 을 적용해주어, 직접 손실 함수를 짤 필요 없다.
반응형
'의학 연구' 카테고리의 다른 글
| [비전공자의 머신러닝 의학연구 #7] 딥러닝과 MLP (0) | 2026.05.08 |
|---|---|
| [비전공자의 머신러닝 의학연구 #6] 부스팅과 XGBoost (0) | 2026.05.07 |
| [비전공자의 머신러닝 의학연구 #5] Random Forest, Random Survival Forest (0) | 2026.05.03 |
| [비전공자의 머신러닝 의학연구 #4] 결정트리 Decision Tree (0) | 2026.05.02 |
| [비전공자의 머신러닝 의학연구 #3] 머신러닝 학습에 알아야할 개념들 (0) | 2026.05.02 |