mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Return a classifier score of the same type as the logits.
PiperOrigin-RevId: 174184871
This commit is contained in:
committed by
TensorFlower Gardener
parent
9da02be116
commit
18bf5b2d91
@@ -297,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||
efficiently run them through the classifier network.
|
||||
|
||||
Returns:
|
||||
The classifier score. A floating-point scalar.
|
||||
The classifier score. A floating-point scalar of the same type as the output
|
||||
of `classifier_fn`.
|
||||
"""
|
||||
generated_images_list = array_ops.split(
|
||||
images, num_or_size_splits=num_batches)
|
||||
@@ -316,7 +317,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||
# Use maximum precision for best results.
|
||||
logits_dtype = logits.dtype
|
||||
if logits_dtype != dtypes.float64:
|
||||
logits = math_ops.cast(logits, dtypes.float64)
|
||||
logits = math_ops.to_double(logits)
|
||||
|
||||
p = nn_ops.softmax(logits)
|
||||
q = math_ops.reduce_mean(p, axis=0)
|
||||
@@ -326,7 +327,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||
final_score = math_ops.exp(log_score)
|
||||
|
||||
if logits_dtype != dtypes.float64:
|
||||
final_score = math_ops.cast(final_score, dtypes.float64)
|
||||
final_score = math_ops.cast(final_score, logits_dtype)
|
||||
return final_score
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user