## ---------------------------------------------------------
## Import libraries
## ---------------------------------------------------------
using CSV
using DataFrames
using LinearAlgebra
using StatsModels
using MLPreprocessing
using StatsBase
using Turing
using StatsFuns: logistic, softplus
using Plots
using Distributions # for Laplace
using KernelDensity # for kernel density estimation
using Plots: Shape # to plot polygons
import Random: AbstractRNG
using Printf # for zero-padded number formatting with @sprintf
## ---------------------------------------------------------
## Preprocessing
## ---------------------------------------------------------
# training data (the one-argument CSV.read form and names! belong to the
# older CSV.jl/DataFrames API; newer releases use CSV.read(path, DataFrame)
# and rename!)
df_train = CSV.read("DILI_raw_data.txt")
# rename columns: replace dots with underscores so names are valid identifiers
colnames = string.(names(df_train))
colnames = map(s -> replace(s, "." => "_"), colnames)
names!(df_train, Symbol.(colnames));
# select predictors
drug_names = df_train[!, :Drug]
df_train = df_train[!, [:Spher, :BSEP, :THP1, :Glu, :Glu_Gal, :ClogP, :log10_cmax, :BA, :dili_sev]];
# additional data
df_test = DataFrame(
    AZID = ["AZD123", "AZD456"],
    Spher = [250, 250],
    BSEP = [314, 1000],
    THP1 = [250, 250],
    Glu = [250, 250],
    Glu_Gal = [1, 1],
    ClogP = [2.501, 3.900],
    BA = [0, 0],
    log10_cmax = [0.5051500, 0.7075702])
# select variables for scaling
df_Xtrain = df_train[!, [:Spher, :BSEP, :THP1, :Glu, :Glu_Gal, :ClogP, :log10_cmax]]
df_Xtest = df_test[!, [:Spher, :BSEP, :THP1, :Glu, :Glu_Gal, :ClogP, :log10_cmax]]
# standardize: fit the scaler on the training data only, then apply the
# same transform to both train and test (avoids information leakage)
scaler = fit(StandardScaler, df_Xtrain)
transform!(df_Xtrain, scaler)
transform!(df_Xtest, scaler)
# re-attach Bioactivation (a binary flag, so it needs no standardisation)
df_Xtrain[!, :BA] = df_train[!, :BA]
df_Xtest[!, :BA] = df_test[!, :BA]
# combine y and X in one DataFrame (hcat auto-names the unnamed y column :x1,
# which is why the formula below uses x1 as the response)
df_yXtrain = hcat(df_train[!, :dili_sev], df_Xtrain);
# all main effects plus all pairwise interactions
# (in StatsModels, a * b expands to a + b + a & b)
f = @formula(x1 ~ 0 + Spher + BSEP + THP1 + Glu + Glu_Gal + ClogP + BA + log10_cmax +
    Spher * (BSEP + THP1 + Glu + Glu_Gal + ClogP + BA) +
    BSEP * (THP1 + Glu + Glu_Gal + ClogP + BA) +
    THP1 * (Glu + Glu_Gal + ClogP + BA) +
    Glu * (Glu_Gal + ClogP + BA) +
    Glu_Gal * (ClogP + BA) +
    ClogP * BA);
# model frame
mf = ModelFrame(f, df_yXtrain)
coefnames(mf)
mm = modelmatrix(mf);
y = convert(Array, df_train[!, :dili_sev]);
X = convert(Matrix, mm);
println(countmap(y))
println(typeof(X))
println(typeof(y))
struct OrderedLogistic{T1, T2 <: AbstractVector} <: DiscreteUnivariateDistribution
    η::T1
    cutpoints::T2
    function OrderedLogistic(η, cutpoints)
        if !issorted(cutpoints)
            error("cutpoints are not sorted")
        end
        return new{typeof(η), typeof(cutpoints)}(η, cutpoints)
    end
end
function Distributions.logpdf(d::OrderedLogistic, k::Int)
    K = length(d.cutpoints) + 1
    c = d.cutpoints
    if k == 1
        logp = -softplus(-(c[k] - d.η))        # log(logistic(c[k] - d.η))
    elseif k < K
        logp = log(logistic(c[k] - d.η) - logistic(c[k-1] - d.η))
    else
        logp = -softplus(c[k-1] - d.η)         # log(1 - logistic(c[k-1] - d.η))
    end
    return logp
end
function Distributions.rand(rng::AbstractRNG, d::OrderedLogistic)
    cutpoints = d.cutpoints
    η = d.η
    c = vcat(-Inf, cutpoints, Inf)
    # P(Y = k) = logistic(η - c[k-1]) - logistic(η - c[k])
    ps = [logistic(η - i[1]) - logistic(η - i[2]) for i in zip(c[1:(end-1)], c[2:end])]
    # validate before sampling: Categorical would error on a negative entry,
    # so the check has to come before rand, not after
    if all(ps .> 0)
        return rand(rng, Categorical(ps))
    else
        return -Inf
    end
end
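# ---------------------------------------------------------
# Sanity check (illustrative, not part of the original analysis):
# exponentiating the logpdf over k = 1:3 should give a valid
# probability vector for any η and sorted cutpoints.
# ---------------------------------------------------------
let d = OrderedLogistic(0.3, [-1.0, 1.0])
    probs = [exp(logpdf(d, k)) for k in 1:3]
    println("OrderedLogistic check, sum(P) = ", sum(probs), ", P = ", round.(probs, digits=3))
end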
## ---------------------------------------------------------
## Define and fit model
## ---------------------------------------------------------
# scale of the half-Normal prior on the Laplace scale parameter sigma
sigma_prior = 1
### Turing model
@model m(X, y, ::Type{VT}=Vector{Float64}) where {VT} = begin
    p = size(X, 2)
    # priors: hierarchical Laplace (shrinkage) prior on the coefficients
    mu ~ Normal(0, 2)
    sigma ~ TruncatedNormal(0, sigma_prior, 0, Inf)
    beta = VT(undef, p)
    for i = 1:p
        beta[i] ~ Laplace(mu, sigma)
    end
    # cutpoints: c2 = c1 + exp(log_diff_c) guarantees the ordering c1 < c2
    c1 ~ Normal(0, 20)
    log_diff_c ~ Normal(0, 2)
    c2 = c1 + exp(log_diff_c)
    c = [c1, c2]
    eta = X * beta
    # likelihood
    for i = 1:length(y)
        y[i] ~ OrderedLogistic(eta[i], c)
    end
end
# sampling; NUTS(iterations, target_acceptance) is the older Turing
# signature (newer releases use sample(model, NUTS(0.65), steps))
steps = 10000
chain = sample(m(X, y), NUTS(steps, 0.65));
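# quick posterior summary as a convergence check (assumes the
# MCMCChains version bundled with Turing provides describe(chain))
describe(chain)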
## ---------------------------------------------------------
## Helper functions
## ---------------------------------------------------------
# posterior prediction for each category
function ps(y_pred_samps, ind)
    y_pred_ind = y_pred_samps[ind, :]
    p = [mean(y_pred_ind .== 1), mean(y_pred_ind .== 2), mean(y_pred_ind .== 3)]
    return p
end
# percent of distribution in each category (blue densities);
# note: relies on the globals c1 and c2 defined in the plotting section below
function percs(post, ind)
    post_ind = post[ind, :]
    prs = [mean(post_ind .<= c1), mean((post_ind .> c1) .& (post_ind .<= c2)), mean(post_ind .> c2)]
    return prs
end
function summary_stats(post_ind)
    # pad with 0 and 1 so the KDE support spans the whole unit interval
    post_ind_01 = vcat(0, post_ind, 1)
    dens = kde(post_ind_01, npoints=512, bandwidth=0.01)
    dens_y = dens.density
    dens_x = dens.x
    pos = argmax(dens_y)   # index of the maximal density value
    post_peak = round(dens_x[pos], digits = 2)
    post_mean = round(mean(post_ind), digits = 2)
    post_median = round(median(post_ind), digits = 2)
    post_q025 = round(quantile(post_ind, 0.025), digits = 2)
    post_q975 = round(quantile(post_ind, 0.975), digits = 2)
    return post_peak, post_mean, post_median, post_q025, post_q975
end
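# illustrative check of summary_stats (not part of the original analysis):
# a Beta(2, 5) sample stands in for one compound's posterior P(DILI) draws
demo_draws = rand(Beta(2, 5), 5_000)
println("summary_stats demo (peak, mean, median, q2.5, q97.5): ", summary_stats(demo_draws))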
## ---------------------------------------------------------
## Predict for training data
## ---------------------------------------------------------
## extract posterior samples
# fixed effects
beta_est = chain[:beta].value.data[:,:,1]';
eta_post = X * beta_est;
# cutpoints
e_log_diff_c = exp.(chain[:log_diff_c].value.data)[:,1,1]
c1_est = chain[:c1].value.data[:,1,1]
c2_est = c1_est + e_log_diff_c;
y_pred_samps = zeros(size(eta_post));
y_pred = zeros(size(eta_post, 1));
for i in 1:size(y_pred_samps, 1)
    for j in 1:size(y_pred_samps, 2)
        c = [c1_est[j], c2_est[j]]
        dist = OrderedLogistic(eta_post[i,j], c)
        y_pred_samps[i,j] = rand(dist)
    end
    probs = [mean(y_pred_samps[i,:] .== 1), mean(y_pred_samps[i,:] .== 2), mean(y_pred_samps[i,:] .== 3)]
    # modal predicted category; argmax avoids summing indices when ties occur
    y_pred[i] = argmax(probs)
end
# in-sample accuracy over the three severity categories
# (note: it is higher in Turing with Normal priors for beta)
println("Accuracy: ", round(mean(y_pred .== y), digits=2))
# collapse to binary: category 1 (no DILI) vs categories 2 and 3
y_bin = Float64.(y .!= 1)
y_pred_bin = Float64.(y_pred .!= 1)
# binary accuracy
println("Binary accuracy: ", round(mean(y_pred_bin .== y_bin), digits=2))
## ---------------------------------------------------------
## Plot results
## ---------------------------------------------------------
## extract predicted values and convert to 0-1 scale
post = logistic.(eta_post);
# cutpoints
c1 = logistic(mean(c1_est))
c2 = logistic(mean(c2_est))
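# report the posterior-mean cutpoints on the probability scale; these are
# the boundaries between the green/gold/red regions in the plots below
println("Cutpoints on 0-1 scale: c1 = ", round(c1, digits=3), ", c2 = ", round(c2, digits=3))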
## calculate average profile for DILI category 3 compounds
y_3 = findall(x -> x == 3, y)
kde_npoints = 2048
av3 = zeros(kde_npoints, length(y_3));
for (counter, ind) in enumerate(y_3)   # enumerate avoids a global counter
    post_ind_01 = vcat(0, post[ind, :], 1)   # pad so the KDE covers [0, 1]
    dens = kde(post_ind_01, npoints=kde_npoints, bandwidth=0.01)
    av3[:, counter] = dens.density
end
# the padded samples share the range [0, 1], so every compound gets the
# same x grid; take it from the first category-3 compound
dens = kde(vcat(0, post[y_3[1], :], 1), npoints=kde_npoints, bandwidth=0.01)
av3_x = dens.x
av3_y = mean(av3, dims=2);
## ---------------------------------------------------------
## Plotting function
## ---------------------------------------------------------
function postplot(post, ind)
    post_ind = post[ind, :]
    dens = kde(post_ind)   # compute the KDE once and reuse it below
    p1 = plot(dens.x, dens.density,
              fill = (0, 0.2, :blue),
              title = drug_names[ind],
              xlabel = "P(DILI)",
              xlims = (0, 1.01),
              ylims = (0, 12),
              legend = false,
              yticks = false,
              framestyle = :box)
    vline!([c1, c2], color = :black, linestyle = :dash)
    # define a function that returns a Plots.Shape
    rectangle(w, h, x, y) = Shape(x .+ [0,w,w,0], y .+ [0,0,h,h])
    plot!(rectangle(c1+0.05, 0.5, -0.05, 0), color = :green, alpha = 0.9)
    plot!(rectangle(c2-c1, 0.5, c1, 0), color = :darkgoldenrod, alpha = 0.9)
    plot!(rectangle(1.05-c2, 0.5, c2, 0), color = :firebrick, alpha = 0.9)
    pers = percs(post, ind)
    s1 = string(convert(Int64, round(pers[1] * 100)), "%")
    s2 = string(convert(Int64, round(pers[2] * 100)), "%")
    s3 = string(convert(Int64, round(pers[3] * 100)), "%")
    annotate!(c1/2, 0.2, text(s1, :white, 8))
    annotate!(c1+(c2-c1)/2, 0.2, text(s2, :white, 8))
    annotate!(c2+0.04, 0.2, text(s3, :white, 8))
    plot!(av3_x, av3_y, color = :black)
    # ======= p2 =======================
    # ========= legend 1 ===============
    p2 = plot(ylim = (0,12), xlim = (0,10), border = :none)
    x_title = 2.3
    y_title = 11.5
    w = 0.5
    h = 0.3
    x_green = 1.5
    y_green = y_title - 0.7
    x_golden = 1.5
    y_golden = y_title - 1.2
    x_red = 1.5
    y_red = y_title - 1.7
    x_box = 0.5
    y_box = 9.5
    w_box = 4.4
    h_box = 2.5
    probs = ps(y_pred_samps, ind)
    # format to two decimals, keeping trailing zeros (Printf)
    s1 = @sprintf("%.2f", probs[1])
    s2 = @sprintf("%.2f", probs[2])
    s3 = @sprintf("%.2f", probs[3])
    plot!(Shape(x_box .+ [0,w_box,w_box,0], y_box .+ [0,0,h_box,h_box]), color = :white, alpha = 0.8, legend = false);
    annotate!(x_title, y_title, text(" Proportions ", :black, 8))
    plot!(Shape(x_green .+ [0,w,w,0], y_green .+ [0,0,h,h]), color = :green, alpha = 0.8, legend = false)
    plot!(Shape(x_golden .+ [0,w,w,0], y_golden .+ [0,0,h,h]), color = :darkgoldenrod, alpha = 0.8, legend = false)
    plot!(Shape(x_red .+ [0,w,w,0], y_red .+ [0,0,h,h]), color = :firebrick, alpha = 0.8, legend = false);
    annotate!(x_green + 1.1, y_green + 0.1, text(s1, 7))
    annotate!(x_green + 1.1, y_golden + 0.1, text(s2, 7))
    annotate!(x_green + 1.1, y_red + 0.1, text(s3, 7))
    # ========= legend 2 ===============
    x_box = 0.5
    y_box = 7.0
    w_box = 4.4
    h_box = 1.5
    x_title = 2.3
    y_title = 8
    x_green = 1.5
    y_green = y_title - 0.7
    plot!(Shape(x_box .+ [0,w_box,w_box,0], y_box .+ [0,0,h_box,h_box]), color = :white, alpha = 0.8, legend = false);
    annotate!(x_title, y_title, text("True category", :black, 8))
    if y[ind] == 1
        plot!(Shape(x_green .+ [0,w,w,0], y_green .+ [0,0,h,h]), color = :green, alpha = 0.8, legend = false)
    elseif y[ind] == 2
        plot!(Shape(x_green .+ [0,w,w,0], y_green .+ [0,0,h,h]), color = :darkgoldenrod, alpha = 0.8, legend = false)
    else
        plot!(Shape(x_green .+ [0,w,w,0], y_green .+ [0,0,h,h]), color = :firebrick, alpha = 0.8, legend = false)
    end
    annotate!(x_green + 0.8, y_green + 0.1, text(string(y[ind]), 7))
    # ======= summary stats =================
    res = summary_stats(post_ind)
    x_box = 0.5
    y_box = 2.7
    w_box = 4.4
    h_box = 3.2
    x_title = 1.0
    y_title = 4.3
    plot!(Shape(x_box .+ [0,w_box,w_box,0], y_box .+ [0,0,h_box,h_box]), color = :white, alpha = 0.8, legend = false);
    s = string(" Summary stats \n \n Peak = ", res[1], " \n Mean = ", res[2], " \n Median = ", res[3], " \n 95% CI = ", res[4], "-", res[5])
    annotate!(x_title, y_title, text(s, :black, :left, 8))
    # ======= display =======================
    p = plot(p1, p2, layout = (1,2), legend = false)
    return p
end
# one diagnostic figure per training compound
for i = 1:length(y)
    p = postplot(post, i)
    display(p)
end
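# to keep the figures on disk, a savefig call (from Plots) could be added
# inside the loop above, e.g. savefig(p, "postplot_$(i).png")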
## ---------------------------------------------------------
## Predict for test data
## ---------------------------------------------------------
# combine y and X in one DataFrame; a dummy y column lets the training formula apply
ytest = DataFrame(x1 = [0,0])
df_yXtest = hcat(ytest, df_Xtest)
# model frame
mf_test = ModelFrame(f, df_yXtest)
mm_test = modelmatrix(mf_test)
X_test = convert(Matrix, mm_test)
eta_test = X_test * beta_est
y_test_samps = zeros(size(eta_test))
y_test_pred = zeros(size(y_test_samps, 1));
for i in 1:size(y_test_samps, 1)
    for j in 1:size(y_test_samps, 2)
        c = [c1_est[j], c2_est[j]]
        dist = OrderedLogistic(eta_test[i,j], c)
        y_test_samps[i,j] = rand(dist)
    end
    probs = [mean(y_test_samps[i,:] .== 1), mean(y_test_samps[i,:] .== 2), mean(y_test_samps[i,:] .== 3)]
    println(probs)
    # modal predicted category; argmax avoids summing indices when ties occur
    y_test_pred[i] = argmax(probs)
end
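# pair the test predictions with their compound identifiers (AZID from
# df_test defined in the preprocessing section)
for (id, pred) in zip(df_test[!, :AZID], y_test_pred)
    println(id, " => predicted DILI category ", Int(pred))
end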