Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python 3.x - Keras, binary segmentation, add weight to loss function

I'm solving a binary segmentation problem with Keras (with the TensorFlow backend). How can I give more weight to the center of each region in the mask?

I've tried a Dice coefficient with an added cv2.erode() step, but it doesn't work:

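import cv2
import numpy as np
import keras.backend as K

def dice_coef_eroded(y_true, y_pred):
    kernel = np.ones((3, 3), np.uint8)  # cv2.erode expects a structuring-element array, not a size tuple
    # NOTE: y_true and y_pred are symbolic tensors here; .eval() cannot turn them into
    # NumPy arrays while the loss is being built, and cv2 output would carry no gradient
    y_true = cv2.erode(y_true.eval(), kernel, iterations=1)
    y_pred = cv2.erode(y_pred.eval(), kernel, iterations=1)
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    # smoothed Dice coefficient: (2*|A∩B| + 1) / (|A| + |B| + 1)
    return (2. * intersection + 1) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1)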

Versions: Keras 2.1.3, TensorFlow 1.4.
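The attempt above fails for two reasons: when the loss is built, y_true and y_pred are symbolic placeholders with no values, so .eval() cannot hand NumPy arrays to OpenCV; and even if it could, the OpenCV result would not be connected to the graph, so no gradient would flow back to the network. Erosion itself is simple, though: eroding a binary mask with a 3x3 square is the same as a 3x3 minimum filter. The NumPy sketch below illustrates that equivalence; it is not the poster's code, and the helper name erode3x3 is hypothetical.

```python
import numpy as np

def erode3x3(mask):
    """Binary erosion with a 3x3 square structuring element, written as a
    3x3 minimum filter (pixels outside the image count as background)."""
    padded = np.pad(mask, 1, mode="constant", constant_values=0)
    h, w = mask.shape
    # stack the 9 shifted views of the padded mask and keep the per-pixel minimum
    shifts = [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return np.min(np.stack(shifts), axis=0)

mask = np.zeros((7, 7), dtype=np.float32)
mask[1:6, 1:6] = 1.0            # a 5x5 square region
eroded = erode3x3(mask)
print(int(eroded.sum()))        # 9: only the inner 3x3 block survives
```

Inside a Keras loss, the same minimum filter could in principle stay in the graph as -K.pool2d(-x, (3, 3), strides=(1, 1), padding='same', pool_mode='max'); treat that as an untested sketch, and note that its border handling differs from the zero padding used here.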


1 Reply


All right, the solution I found is the following:

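1) In your data iterator, add a method that returns a weight matrix with the same shape as the mask, so that each batch contains [image, mask, weights]. Feed these arrays in the order the model's inputs expect (in the example below: images, weights, masks).

2) Create a Lambda layer that computes the loss inside the model, so that the weights and the true masks can be passed in as ordinary inputs.

3) Create an identity loss function that simply returns the model's output (the already-computed loss), and compile the model with it.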
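For step 1, the answer leaves the weight matrix itself open. One hypothetical way to put more weight on region centers is to weight each pixel by its depth inside the mask, measured here as the number of successive 3x3 erosions it survives (its chessboard distance to the background). The sketch assumes NumPy and a single 2-D binary mask; the names center_weight_map and alpha are made up for illustration, and scipy.ndimage.distance_transform_edt would give a Euclidean distance instead.

```python
import numpy as np

def erode3x3(mask):
    # 3x3 minimum filter; pixels outside the image count as background
    padded = np.pad(mask, 1, mode="constant", constant_values=0)
    h, w = mask.shape
    shifts = [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return np.min(np.stack(shifts), axis=0)

def center_weight_map(mask, alpha=1.0):
    """Hypothetical weight map: 1 everywhere, plus alpha times the pixel's
    depth inside the mask, normalised by the deepest pixel of the mask."""
    depth = np.zeros(mask.shape, dtype=np.float32)
    current = (mask > 0).astype(np.float32)
    while current.any():          # each pass peels off one boundary ring
        depth += current          # pixels still inside gain one level of depth
        current = erode3x3(current)
    if depth.max() > 0:
        depth /= depth.max()
    return 1.0 + alpha * depth

mask = np.zeros((7, 7), dtype=np.float32)
mask[1:6, 1:6] = 1.0              # a 5x5 square region
weights = center_weight_map(mask)
print(weights[3, 3], weights[1, 1], weights[0, 0])  # center 2.0, edge about 1.33, background 1.0
```

Normalising by the deepest pixel of the whole mask means small regions get smaller center weights than large ones; normalising per region would need connected-component labelling first. For an (H, W, 1) mask, call it on mask[..., 0] and restore the channel axis with [..., None].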

Example:

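import keras.backend as K

def weighted_binary_loss(X):
    # X is the list of tensors passed to the Lambda layer
    y_pred, weights, y_true = X
    # per-pixel binary cross-entropy; the Keras 2 argument order is (target, output)
    loss = K.binary_crossentropy(y_true, y_pred)
    # scale each pixel's loss by its weight (element-wise product)
    loss = loss * weights
    return loss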
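The Lambda layer computes, for each pixel, the binary cross-entropy multiplied by that pixel's weight. A small NumPy reference of the same arithmetic makes the effect of a weight of 2 concrete. It is a sketch, only approximately what K.binary_crossentropy computes (Keras also clips predictions by a small epsilon), and the function name weighted_bce_map is hypothetical.

```python
import numpy as np

def weighted_bce_map(y_true, y_pred, weights, eps=1e-7):
    """Per-pixel weighted binary cross-entropy:
    -(y*log(p) + (1-y)*log(1-p)) * w, with p clipped away from 0 and 1."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return bce * weights

y_true = np.array([1.0, 1.0, 0.0])
y_pred = np.array([0.9, 0.9, 0.2])
weights = np.array([1.0, 2.0, 1.0])   # the second pixel is treated as a "center" pixel
loss = weighted_bce_map(y_true, y_pred, weights)
print(loss.round(4))                  # per-pixel losses of about 0.1054, 0.2107, 0.2231
```

The two identical predictions differ only in their weights, so the center pixel contributes exactly twice as much to the loss and to its gradient.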

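def identity_loss(y_true, y_pred):
    # the model's output already is the weighted per-pixel loss, so return it as-is;
    # y_true is ignored (any dummy target with the output's shape will do)
    return y_pred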

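from keras.layers import Input, Conv2D, Lambda
from keras.models import Model

def get_unet_w_lambda_loss(input_shape=(1024, 1024, 3), mask_shape=(1024, 1024, 1)):
    images = Input(input_shape)
    mask_weights = Input(mask_shape)
    true_masks = Input(mask_shape)
    ...
    y_pred = Conv2D(1, (1, 1), activation='sigmoid')(up1)  # output of the original U-Net
    # compute the weighted per-pixel loss inside the graph; it becomes the model's output
    loss = Lambda(weighted_binary_loss, output_shape=mask_shape)([y_pred, mask_weights, true_masks])
    model = Model(inputs=[images, mask_weights, true_masks], outputs=loss)
    # compile with the identity loss, e.g. model.compile(optimizer='adam', loss=identity_loss)
    return model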
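Because the weights and true masks are model inputs and identity_loss ignores its target, the iterator from step 1 has to yield the three arrays in the order of Model(inputs=[images, mask_weights, true_masks]) plus a dummy target with the output's shape. A minimal generator sketch, assuming NumPy arrays held in memory and a hypothetical weight_fn that maps one mask to its weight matrix:

```python
import numpy as np

def weighted_batches(images, masks, weight_fn, batch_size=2):
    """Hypothetical generator: yields ([images, weights, masks], dummy_target),
    matching the input order of the model above."""
    n = len(images)
    while True:  # Keras generators are expected to loop indefinitely
        for start in range(0, n, batch_size):
            x = images[start:start + batch_size]
            y = masks[start:start + batch_size]
            w = np.stack([weight_fn(m) for m in y])   # one weight matrix per mask
            # identity_loss never reads the target, but Keras still needs one of the output's shape
            dummy = np.zeros((len(x),) + y.shape[1:], dtype=np.float32)
            yield [x, w, y], dummy

images = np.random.rand(4, 8, 8, 3).astype(np.float32)
masks = (np.random.rand(4, 8, 8, 1) > 0.5).astype(np.float32)
gen = weighted_batches(images, masks, weight_fn=np.ones_like)
inputs, target = next(gen)
x_batch, w_batch, y_batch = inputs
print(x_batch.shape, w_batch.shape, y_batch.shape, target.shape)
```

With Keras 2.1.x, such a generator would be passed to model.fit_generator(gen, steps_per_epoch=...) after compiling with identity_loss; the number of steps depends on your dataset.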


...