본문 바로가기
의학 연구

[비전공자의 머신러닝 의학연구 #8] DeepSurv — Cox 회귀 + 신경망

by Dr CK 2026. 5. 10.

랜덤포레스트를 생존분석에 확장한것이 Random Survival Forest 였다면, 딥러닝 모델을 생존분석에 적용한 모델은 DeepSurv 이다. Katzman et al. 2018, BMC Medical Research Methodology (https://link.springer.com/article/10.1186/s12874-018-0482-1) 에서 소개된 모델로, Cox partial likelihood를 신경망의 손실 함수로 직접 사용 하는 것을 제시하여 생존분석 딥러닝의 시작점이 되는 모형이다.

 

 

1. DeepSurv의 핵심 — Cox β'X 를 신경망으로 대체

DeepSurv는 한마디로 Cox 비례위험 모형의 β'X 를 신경망으로 대체 한 것이다.

  Cox 회귀 DeepSurv
  h(t|X) = h₀(t) · exp(β₁X₁ + β₂X₂ + ... + βₚXₚ) h(t|X) = h₀(t) · exp(f(X))
위험 점수 β'X (선형) f(X) (MLP의 비선형)
비선형 학습 불가 (직접 다항·spline 추가) 자동 학습
상호작용 직접 추가 자동 학습

"Cox의 통계적 프레임워크 + 신경망의 표현력" 의 결합이라고 보면 된다.

 

2. 손실 함수 — Cox Partial Likelihood 그대로

DeepSurv의 결정적 특징 — 손실 함수는 Cox 부분우도 (partial likelihood) 를 그대로 사용 한다.

L = − Σᵢ [ ηᵢ − log Σⱼ∈R(tᵢ) exp(ηⱼ) ]

여기서 :

  • ηᵢ = MLP의 출력 (환자 i의 위험 점수)
  • R(tᵢ) = 시점 tᵢ 의 risk set (그때까지 살아있는 환자들)

이 손실 함수의 의미 :

  • 사건이 발생한 환자의 위험 점수가 높아지도록 학습
  • 검열(censoring) 자동 처리 — 사건 발생자만 우도에 포함
  • Cox와 동일한 통계적 프레임워크 — hazard ratio 해석 도 유지

 

3. 학습 흐름

  1. 환자 데이터 입력 — 시간-고정 변수 X + 생존 정보 (T, δ)
  2. MLP 통과 — η = f(X) 로 위험 점수 출력
  3. Cox 부분우도 손실 — 사건 발생자의 η가 높아지도록 손실 계산
  4. 가중치 업데이트 — Backpropagation + Adam 옵티마이저로 가중치 조정

이 과정을 손실이 줄어들 때까지 반복한다. 이전 #7편의 MLP 학습 4단계 와 같은 구조에 — 손실 함수만 Cox 부분우도로 바꾼 것이다.

 

4. 파이썬 코드 — PyCox 라이브러리

pycox 로 DeepSurv의 구현이 가능하며, pycoxPyTorch 위에서 작동하는 생존분석 전용 wrapper 라이브러리이다.

pip install pycox

4-1. 데이터 준비

import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from pycox.models import CoxPH
import torchtuples as tt

# 데이터 로드
df = pd.read_csv('survival_data.csv')

# 변수 (X) + 생존 정보 (T, δ)
features = ['age', 'tumor_size', 'T_stage', 'N_stage']
X = df[features].values.astype('float32')
T = df['time'].values.astype('float32')      # 추적 시간
E = df['event'].values.astype('float32')     # 사건 (1) / 검열 (0)

# 분할
X_tr, X_te, T_tr, T_te, E_tr, E_te = train_test_split(
    X, T, E, test_size=0.2, random_state=42
)

# 정규화 (딥러닝에선 필수)
scaler = StandardScaler()
X_tr = scaler.fit_transform(X_tr).astype('float32')
X_te = scaler.transform(X_te).astype('float32')

# pycox 형식으로 변환
y_tr = (T_tr, E_tr)
y_te = (T_te, E_te)

4-2. 모델 정의

import torch.nn as nn

class DeepSurvNet(nn.Module):
    def __init__(self, in_features, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)        # 출력: 위험 점수 1개
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

net = DeepSurvNet(in_features=X_tr.shape[1])

핵심 — 출력 차원이 1 (위험 점수 η 하나만). 일반 분류는 클래스 개수만큼이지만 DeepSurv는 항상 1 이다.

4-3. 학습

# pycox의 CoxPH 모델로 감싸기
model = CoxPH(net, optimizer=tt.optim.Adam(0.001))

# 학습
batch_size = 256
epochs = 100
log = model.fit(
    X_tr, y_tr,
    batch_size=batch_size,
    epochs=epochs,
    val_data=(X_te, y_te),
    verbose=True
)

# Baseline hazard 계산 (생존 곡선 그리려면 필수)
_ = model.compute_baseline_hazards()

CoxPH 가 내부적으로 Cox 부분우도 손실 을 적용해주어, 직접 손실 함수를 짤 필요 없다.

반응형