Clarifying global floating point policy #357

Open
Sinacam opened this issue Oct 21, 2023 · 0 comments
Sinacam commented Oct 21, 2023

The issue of float precision affects many computations in tensorflow_ranking, such as:

def _compute_impl(self, labels, predictions, weights, mask):
  """See `_RankingMetric`."""
  topn = tf.shape(predictions)[1] if self._topn is None else self._topn
  # Relevance = 1.0 when labels >= 1.0.
  relevance = tf.cast(tf.greater_equal(labels, 1.0), dtype=tf.float32)
  sorted_relevance, sorted_weights = utils.sort_by_scores(
      predictions, [relevance, weights], topn=topn, mask=mask)
  per_list_relevant_counts = tf.cumsum(sorted_relevance, axis=1)
  per_list_cutoffs = tf.cumsum(tf.ones_like(sorted_relevance), axis=1)
  per_list_precisions = tf.math.divide_no_nan(per_list_relevant_counts,
                                              per_list_cutoffs)
  total_precision = tf.reduce_sum(
      input_tensor=per_list_precisions * sorted_weights * sorted_relevance,
      axis=1,
      keepdims=True)
  # Compute the total relevance regardless of self._topn.
  total_relevance = tf.reduce_sum(
      input_tensor=weights * relevance, axis=1, keepdims=True)
  per_list_map = tf.math.divide_no_nan(total_precision, total_relevance)
  # per_list_weights are computed from the whole list to avoid the problem of
  # 0 when there is no relevant example in topn.
  per_list_weights = _per_example_weights_to_per_list_weights(
      weights, relevance)
  return per_list_map, per_list_weights
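
The cast to a hardcoded tf.float32 above is what pins the dtype. As a contrast, here is a minimal sketch (my own illustration, not library code; `_relevance_like` is a hypothetical name) of the same cast following the labels' dtype instead:

import tensorflow as tf

def _relevance_like(labels):
  # Hypothetical variant of the cast above: relevance = 1.0 when labels >= 1.0,
  # using the labels' own dtype rather than a hardcoded tf.float32.
  return tf.cast(tf.greater_equal(labels, 1.0), dtype=labels.dtype)

print(_relevance_like(tf.constant([[0.0, 2.0]], tf.float32)).dtype)  # <dtype: 'float32'>
print(_relevance_like(tf.constant([[0.0, 2.0]], tf.float64)).dtype)  # <dtype: 'float64'>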

This has been mentioned before in #254, but I want to elaborate on our difficulties.
Hardcoded dtypes like this make it extremely hard to move our programs to float64.
For example, if we call tf.keras.backend.set_floatx('float64') anywhere, we get errors inside tensorflow_ranking due to conflicting dtypes.
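
Here is a minimal reproduction of the kind of failure we see, written against plain TensorFlow to mirror the pattern in the snippet above (it is not the actual tensorflow_ranking call path):

import tensorflow as tf

tf.keras.backend.set_floatx('float64')

# Tensors created at the global floatx are float64 ...
labels = tf.constant([[0.0, 1.0, 2.0]], dtype=tf.keras.backend.floatx())
weights = tf.ones_like(labels)

# ... while the hardcoded cast produces float32, as in the snippet above.
relevance = tf.cast(tf.greater_equal(labels, 1.0), dtype=tf.float32)

# Mixing float64 and float32 operands here raises an InvalidArgumentError.
total_relevance = tf.reduce_sum(weights * relevance, axis=1, keepdims=True)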

Will the global floating point policy (tf.keras.mixed_precision.set_global_policy and tf.keras.backend.floatx) be supported?
If the official stance is to ignore the global policy, can that be documented?
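
For concreteness, a sketch of what honoring those settings could look like: resolve the working dtype from tf.keras.backend.floatx() (or the global mixed-precision policy) and cast with that instead of tf.float32. This is only an illustration of the request, not existing tensorflow_ranking behavior:

import tensorflow as tf

tf.keras.backend.set_floatx('float64')
print(tf.keras.backend.floatx())                               # 'float64'

tf.keras.mixed_precision.set_global_policy('float64')
print(tf.keras.mixed_precision.global_policy().compute_dtype)  # 'float64'

# A metric honoring these settings would cast with the resolved dtype, e.g.
labels = tf.constant([[0.0, 1.0, 2.0]], dtype=tf.keras.backend.floatx())
relevance = tf.cast(tf.greater_equal(labels, 1.0),
                    dtype=tf.keras.backend.floatx())
print(relevance.dtype)                                         # <dtype: 'float64'>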
