ryanthompson591 commented on code in PR #22924: URL: https://github.com/apache/beam/pull/22924#discussion_r961885728
########## sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py: ########## @@ -360,16 +360,14 @@ def __init__(self, self, data_plane.InMemoryDataChannel(), state, provision_info) self.control_conn = self # type: ignore # need Protocol to describe this self.data_conn = self.data_plane_handler - state_cache = StateCache(STATE_CACHE_SIZE) + state_cache = StateCache(STATE_CACHE_SIZE_MB << 20) Review Comment: I like the idea of using a constant here in some way like: STATE_CACHE_SIZE_MB << MB_TO_BYTES ########## sdks/python/apache_beam/runners/worker/statecache.py: ########## @@ -20,245 +20,180 @@ # mypy: disallow-untyped-defs import collections +import gc import logging import threading -from typing import TYPE_CHECKING from typing import Any -from typing import Callable -from typing import Generic -from typing import Hashable from typing import List from typing import Optional -from typing import Set from typing import Tuple -from typing import TypeVar -from apache_beam.metrics import monitoring_infos - -if TYPE_CHECKING: - from apache_beam.portability.api import metrics_pb2 +import objsize _LOGGER = logging.getLogger(__name__) -CallableT = TypeVar('CallableT', bound='Callable') -KT = TypeVar('KT') -VT = TypeVar('VT') +class WeightedValue(object): + """Value type that stores corresponding weight. -class Metrics(object): - """Metrics container for state cache metrics.""" + :arg value The value to be stored. + :arg weight The associated weight of the value. If unspecified, the objects + size will be used. 
+ """ + def __init__(self, value, weight): + # type: (Any, int) -> None + self._value = value + if weight <= 0: + raise ValueError( + 'Expected weight to be > 0 for %s but received %d' % (value, weight)) + self._weight = weight + + def weight(self): + # type: () -> int + return self._weight - # A set of all registered metrics - ALL_METRICS = set() # type: Set[Hashable] - PREFIX = "beam:metric:statecache:" + def value(self): + # type: () -> Any + return self._value - def __init__(self): - # type: () -> None - self._context = threading.local() - def initialize(self): +class CacheAware(object): + def __init__(self): # type: () -> None + pass - """Needs to be called once per thread to initialize the local metrics cache. - """ - if hasattr(self._context, 'metrics'): - return # Already initialized - self._context.metrics = collections.defaultdict(int) - - def count(self, name): - # type: (str) -> None - self._context.metrics[name] += 1 - - def hit_miss(self, total_name, hit_miss_name): - # type: (str, str) -> None - self._context.metrics[total_name] += 1 - self._context.metrics[hit_miss_name] += 1 + def get_referents_for_cache(self): + # type: () -> List[Any] - def get_monitoring_infos(self, cache_size, cache_capacity): - # type: (int, int) -> List[metrics_pb2.MonitoringInfo] + """Returns the list of objects accounted during cache measurement.""" + raise NotImplementedError() - """Returns the metrics scoped to the current bundle.""" - metrics = self._context.metrics - if len(metrics) == 0: - # No metrics collected, do not report - return [] - # Add all missing metrics which were not reported - for key in Metrics.ALL_METRICS: - if key not in metrics: - metrics[key] = 0 - # Gauges which reflect the state since last queried - gauges = [ - monitoring_infos.int64_gauge(self.PREFIX + name, val) for name, - val in metrics.items() - ] - gauges.append( - monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size)) - gauges.append( - monitoring_infos.int64_gauge(self.PREFIX + 
'capacity', cache_capacity)) - # Counters for the summary across all metrics - counters = [ - monitoring_infos.int64_counter(self.PREFIX + name + '_total', val) - for name, - val in metrics.items() - ] - # Reinitialize metrics for this thread/bundle - metrics.clear() - return gauges + counters - @staticmethod - def counter_hit_miss(total_name, hit_name, miss_name): - # type: (str, str, str) -> Callable[[CallableT], CallableT] +def get_referents_for_cache(*objs): + # type: (List[Any]) -> List[Any] - """Decorator for counting function calls and whether - the return value equals None (=miss) or not (=hit).""" - Metrics.ALL_METRICS.update([total_name, hit_name, miss_name]) + """Returns the list of objects accounted during cache measurement. - def decorator(function): - # type: (CallableT) -> CallableT - def reporter(self, *args, **kwargs): - # type: (StateCache, Any, Any) -> Any - value = function(self, *args, **kwargs) - if value is None: - self._metrics.hit_miss(total_name, miss_name) - else: - self._metrics.hit_miss(total_name, hit_name) - return value - - return reporter # type: ignore[return-value] - - return decorator - - @staticmethod - def counter(metric_name): - # type: (str) -> Callable[[CallableT], CallableT] - - """Decorator for counting function calls.""" - Metrics.ALL_METRICS.add(metric_name) - - def decorator(function): - # type: (CallableT) -> CallableT - def reporter(self, *args, **kwargs): - # type: (StateCache, Any, Any) -> Any - self._metrics.count(metric_name) - return function(self, *args, **kwargs) - - return reporter # type: ignore[return-value] - - return decorator + Users can inherit CacheAware to override which referrents should be + used when measuring the deep size of the object. The default is to + use gc.get_referents(*objs). 
+ """ + # print(objs) + rval = [] + for obj in objs: + if isinstance(obj, CacheAware): + rval.extend(obj.get_referents_for_cache()) + else: + rval.extend(gc.get_referents(obj)) + return rval class StateCache(object): - """ Cache for Beam state access, scoped by state key and cache_token. - Assumes a bag state implementation. + """Cache for Beam state access, scoped by state key and cache_token. + Assumes a bag state implementation. - For a given state_key, caches a (cache_token, value) tuple and allows to + For a given state_key and cache_token, caches a value and allows to a) read from the cache (get), if the currently stored cache_token matches the provided - a) write to the cache (put), + b) write to the cache (put), storing the new value alongside with a cache token - c) append to the currently cache item (extend), - if the currently stored cache_token matches the provided c) empty a cached element (clear), if the currently stored cache_token matches the provided - d) evict a cached element (evict) + d) invalidate a cached element (invalidate) + e) invalidate all cached elements (invalidate_all) The operations on the cache are thread-safe for use by multiple workers. - :arg max_entries The maximum number of entries to store in the cache. - TODO Memory-based caching: https://github.com/apache/beam/issues/19857 + :arg max_weight The maximum weight of entries to store in the cache in bytes. 
""" - def __init__(self, max_entries): + def __init__(self, max_weight): # type: (int) -> None - _LOGGER.info('Creating state cache with size %s', max_entries) - self._missing = None - self._cache = self.LRUCache[Tuple[bytes, Optional[bytes]], - Any](max_entries, self._missing) + _LOGGER.info('Creating state cache with size %s', max_weight) + self._max_weight = max_weight + self._current_weight = 0 + self._cache = collections.OrderedDict( + ) # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue] + self._hit_count = 0 + self._miss_count = 0 + self._evict_count = 0 self._lock = threading.RLock() - self._metrics = Metrics() - @Metrics.counter_hit_miss("get", "hit", "miss") def get(self, state_key, cache_token): # type: (bytes, Optional[bytes]) -> Any assert cache_token and self.is_cache_enabled() + key = (state_key, cache_token) with self._lock: - return self._cache.get((state_key, cache_token)) + value = self._cache.get(key, None) + if value is None: + self._miss_count += 1 + return None + self._cache.move_to_end(key) + self._hit_count += 1 + return value.value() - @Metrics.counter("put") def put(self, state_key, cache_token, value): # type: (bytes, Optional[bytes], Any) -> None assert cache_token and self.is_cache_enabled() + if not isinstance(value, WeightedValue): + weight = objsize.get_deep_size( + value, get_referents_func=get_referents_for_cache) + if weight <= 0: + _LOGGER.warning( Review Comment: here this logs a warning and elsewhere it is an exception. What's up with that? 
########## sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py: ########## @@ -653,7 +651,8 @@ def __init__(self, from apache_beam.transforms.environments import EmbeddedPythonGrpcEnvironment config = EmbeddedPythonGrpcEnvironment.parse_config(payload.decode('utf-8')) - self._state_cache_size = config.get('state_cache_size') or STATE_CACHE_SIZE + self._state_cache_size = ( + config.get('state_cache_size') or STATE_CACHE_SIZE_MB) << 20 Review Comment: Is it clear that state_cache_size in the config is in MB? ########## sdks/python/apache_beam/runners/worker/statecache_test.py: ########## @@ -21,209 +21,157 @@ import logging import unittest -from apache_beam.metrics import monitoring_infos +from apache_beam.runners.worker.statecache import CacheAware from apache_beam.runners.worker.statecache import StateCache +from apache_beam.runners.worker.statecache import WeightedValue class StateCacheTest(unittest.TestCase): def test_empty_cache_get(self): - cache = self.get_cache(5) + cache = StateCache(5 << 20) self.assertEqual(cache.get("key", 'cache_token'), None) with self.assertRaises(Exception): # Invalid cache token provided self.assertEqual(cache.get("key", None), None) - self.verify_metrics( - cache, - { - 'get': 1, - 'put': 0, - 'miss': 1, - 'hit': 0, - 'clear': 0, - 'evict': 0, - 'size': 0, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0') def test_put_get(self): - cache = self.get_cache(5) - cache.put("key", "cache_token", "value") + cache = StateCache(5 << 20) + cache.put("key", "cache_token", WeightedValue("value", 1 << 20)) self.assertEqual(cache.size(), 1) self.assertEqual(cache.get("key", "cache_token"), "value") self.assertEqual(cache.get("key", "cache_token2"), None) with self.assertRaises(Exception): self.assertEqual(cache.get("key", None), None) - self.verify_metrics( - cache, - { - 'get': 2, - 'put': 1, - 'miss': 1, - 'hit': 1, - 'clear': 0, - 'evict': 0, - 
'size': 1, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0') def test_clear(self): - cache = self.get_cache(5) + cache = StateCache(5 << 20) cache.clear("new-key", "cache_token") - cache.put("key", "cache_token", ["value"]) + cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20)) Review Comment: I might be confused, but isn't a weighted value something different from a regular value? Would it make sense for these tests to stick with regular values, and then add a separate set of tests for weighted values? ########## sdks/python/apache_beam/runners/worker/statecache.py: ########## @@ -20,245 +20,180 @@ # mypy: disallow-untyped-defs import collections +import gc import logging import threading -from typing import TYPE_CHECKING from typing import Any -from typing import Callable -from typing import Generic -from typing import Hashable from typing import List from typing import Optional -from typing import Set from typing import Tuple -from typing import TypeVar -from apache_beam.metrics import monitoring_infos - -if TYPE_CHECKING: - from apache_beam.portability.api import metrics_pb2 +import objsize _LOGGER = logging.getLogger(__name__) -CallableT = TypeVar('CallableT', bound='Callable') -KT = TypeVar('KT') -VT = TypeVar('VT') +class WeightedValue(object): + """Value type that stores corresponding weight. -class Metrics(object): - """Metrics container for state cache metrics.""" + :arg value The value to be stored. + :arg weight The associated weight of the value. If unspecified, the objects + size will be used.
+ """ + def __init__(self, value, weight): + # type: (Any, int) -> None + self._value = value + if weight <= 0: + raise ValueError( + 'Expected weight to be > 0 for %s but received %d' % (value, weight)) + self._weight = weight + + def weight(self): + # type: () -> int + return self._weight - # A set of all registered metrics - ALL_METRICS = set() # type: Set[Hashable] - PREFIX = "beam:metric:statecache:" + def value(self): + # type: () -> Any + return self._value - def __init__(self): - # type: () -> None - self._context = threading.local() - def initialize(self): +class CacheAware(object): + def __init__(self): # type: () -> None + pass - """Needs to be called once per thread to initialize the local metrics cache. - """ - if hasattr(self._context, 'metrics'): - return # Already initialized - self._context.metrics = collections.defaultdict(int) - - def count(self, name): - # type: (str) -> None - self._context.metrics[name] += 1 - - def hit_miss(self, total_name, hit_miss_name): - # type: (str, str) -> None - self._context.metrics[total_name] += 1 - self._context.metrics[hit_miss_name] += 1 + def get_referents_for_cache(self): + # type: () -> List[Any] - def get_monitoring_infos(self, cache_size, cache_capacity): - # type: (int, int) -> List[metrics_pb2.MonitoringInfo] + """Returns the list of objects accounted during cache measurement.""" + raise NotImplementedError() - """Returns the metrics scoped to the current bundle.""" - metrics = self._context.metrics - if len(metrics) == 0: - # No metrics collected, do not report - return [] - # Add all missing metrics which were not reported - for key in Metrics.ALL_METRICS: - if key not in metrics: - metrics[key] = 0 - # Gauges which reflect the state since last queried - gauges = [ - monitoring_infos.int64_gauge(self.PREFIX + name, val) for name, - val in metrics.items() - ] - gauges.append( - monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size)) - gauges.append( - monitoring_infos.int64_gauge(self.PREFIX + 
'capacity', cache_capacity)) - # Counters for the summary across all metrics - counters = [ - monitoring_infos.int64_counter(self.PREFIX + name + '_total', val) - for name, - val in metrics.items() - ] - # Reinitialize metrics for this thread/bundle - metrics.clear() - return gauges + counters - @staticmethod - def counter_hit_miss(total_name, hit_name, miss_name): - # type: (str, str, str) -> Callable[[CallableT], CallableT] +def get_referents_for_cache(*objs): + # type: (List[Any]) -> List[Any] - """Decorator for counting function calls and whether - the return value equals None (=miss) or not (=hit).""" - Metrics.ALL_METRICS.update([total_name, hit_name, miss_name]) + """Returns the list of objects accounted during cache measurement. - def decorator(function): - # type: (CallableT) -> CallableT - def reporter(self, *args, **kwargs): - # type: (StateCache, Any, Any) -> Any - value = function(self, *args, **kwargs) - if value is None: - self._metrics.hit_miss(total_name, miss_name) - else: - self._metrics.hit_miss(total_name, hit_name) - return value - - return reporter # type: ignore[return-value] - - return decorator - - @staticmethod - def counter(metric_name): - # type: (str) -> Callable[[CallableT], CallableT] - - """Decorator for counting function calls.""" - Metrics.ALL_METRICS.add(metric_name) - - def decorator(function): - # type: (CallableT) -> CallableT - def reporter(self, *args, **kwargs): - # type: (StateCache, Any, Any) -> Any - self._metrics.count(metric_name) - return function(self, *args, **kwargs) - - return reporter # type: ignore[return-value] - - return decorator + Users can inherit CacheAware to override which referrents should be + used when measuring the deep size of the object. The default is to + use gc.get_referents(*objs). + """ + # print(objs) Review Comment: remove this line? 
########## sdks/python/apache_beam/runners/worker/statecache_test.py: ########## @@ -21,209 +21,157 @@ import logging import unittest -from apache_beam.metrics import monitoring_infos +from apache_beam.runners.worker.statecache import CacheAware from apache_beam.runners.worker.statecache import StateCache +from apache_beam.runners.worker.statecache import WeightedValue class StateCacheTest(unittest.TestCase): def test_empty_cache_get(self): - cache = self.get_cache(5) + cache = StateCache(5 << 20) self.assertEqual(cache.get("key", 'cache_token'), None) with self.assertRaises(Exception): # Invalid cache token provided self.assertEqual(cache.get("key", None), None) - self.verify_metrics( - cache, - { - 'get': 1, - 'put': 0, - 'miss': 1, - 'hit': 0, - 'clear': 0, - 'evict': 0, - 'size': 0, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0') def test_put_get(self): - cache = self.get_cache(5) - cache.put("key", "cache_token", "value") + cache = StateCache(5 << 20) + cache.put("key", "cache_token", WeightedValue("value", 1 << 20)) self.assertEqual(cache.size(), 1) self.assertEqual(cache.get("key", "cache_token"), "value") self.assertEqual(cache.get("key", "cache_token2"), None) with self.assertRaises(Exception): self.assertEqual(cache.get("key", None), None) - self.verify_metrics( - cache, - { - 'get': 2, - 'put': 1, - 'miss': 1, - 'hit': 1, - 'clear': 0, - 'evict': 0, - 'size': 1, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0') def test_clear(self): - cache = self.get_cache(5) + cache = StateCache(5 << 20) cache.clear("new-key", "cache_token") - cache.put("key", "cache_token", ["value"]) + cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20)) self.assertEqual(cache.size(), 2) self.assertEqual(cache.get("new-key", "new_token"), None) self.assertEqual(cache.get("key", "cache_token"), ['value']) # test clear 
without existing key/token cache.clear("non-existing", "token") self.assertEqual(cache.size(), 3) self.assertEqual(cache.get("non-existing", "token"), []) - self.verify_metrics( - cache, - { - 'get': 3, - 'put': 1, - 'miss': 1, - 'hit': 2, - 'clear': 2, - 'evict': 0, - 'size': 3, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0') + + def test_default_sized_put(self): + cache = StateCache(5 << 20) + cache.put("key", "cache_token", bytearray(1 << 20)) + cache.put("key2", "cache_token", bytearray(1 << 20)) + cache.put("key3", "cache_token", bytearray(1 << 20)) + self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20)) + cache.put("key4", "cache_token", bytearray(1 << 20)) + cache.put("key5", "cache_token", bytearray(1 << 20)) + # note that each byte array instance takes slightly over 1 MB which is why + # these 5 byte arrays can't all be stored in the cache causing a single + # eviction + self.assertEqual( + cache.describe_stats(), + 'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1') def test_max_size(self): - cache = self.get_cache(2) - cache.put("key", "cache_token", "value") - cache.put("key2", "cache_token", "value") - self.assertEqual(cache.size(), 2) - cache.put("key2", "cache_token", "value") + cache = StateCache(2 << 20) + cache.put("key", "cache_token", WeightedValue("value", 1 << 20)) + cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20)) self.assertEqual(cache.size(), 2) - cache.put("key", "cache_token", "value") + cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20)) self.assertEqual(cache.size(), 2) - self.verify_metrics( - cache, - { - 'get': 0, - 'put': 4, - 'miss': 0, - 'hit': 0, - 'clear': 0, - 'evict': 0, - 'size': 2, - 'capacity': 2 - }) - - def test_evict_all(self): - cache = self.get_cache(5) - cache.put("key", "cache_token", "value") - cache.put("key2", "cache_token", "value2") + self.assertEqual( + cache.describe_stats(), + 
'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1') + + def test_invalidate_all(self): + cache = StateCache(5 << 20) + cache.put("key", "cache_token", WeightedValue("value", 1 << 20)) + cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20)) self.assertEqual(cache.size(), 2) - cache.evict_all() + cache.invalidate_all() self.assertEqual(cache.size(), 0) self.assertEqual(cache.get("key", "cache_token"), None) self.assertEqual(cache.get("key2", "cache_token"), None) - self.verify_metrics( - cache, - { - 'get': 2, - 'put': 2, - 'miss': 2, - 'hit': 0, - 'clear': 0, - 'evict': 0, - 'size': 0, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0') def test_lru(self): - cache = self.get_cache(5) - cache.put("key", "cache_token", "value") - cache.put("key2", "cache_token2", "value2") - cache.put("key3", "cache_token", "value0") - cache.put("key3", "cache_token", "value3") - cache.put("key4", "cache_token4", "value4") - cache.put("key5", "cache_token", "value0") - cache.put("key5", "cache_token", ["value5"]) + cache = StateCache(5 << 20) + cache.put("key", "cache_token", WeightedValue("value", 1 << 20)) + cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20)) + cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20)) + cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20)) + cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20)) + cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20)) + cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20)) self.assertEqual(cache.size(), 5) self.assertEqual(cache.get("key", "cache_token"), "value") self.assertEqual(cache.get("key2", "cache_token2"), "value2") self.assertEqual(cache.get("key3", "cache_token"), "value3") self.assertEqual(cache.get("key4", "cache_token4"), "value4") self.assertEqual(cache.get("key5", "cache_token"), ["value5"]) # insert another key to trigger cache eviction 
- cache.put("key6", "cache_token2", "value7") + cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20)) self.assertEqual(cache.size(), 5) # least recently used key should be gone ("key") self.assertEqual(cache.get("key", "cache_token"), None) # trigger a read on "key2" cache.get("key2", "cache_token2") # insert another key to trigger cache eviction - cache.put("key7", "cache_token", "value7") + cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20)) self.assertEqual(cache.size(), 5) # least recently used key should be gone ("key3") self.assertEqual(cache.get("key3", "cache_token"), None) # trigger a put on "key2" - cache.put("key2", "cache_token", "put") + cache.put("key2", "cache_token", WeightedValue("put", 1 << 20)) self.assertEqual(cache.size(), 5) # insert another key to trigger cache eviction - cache.put("key8", "cache_token", "value8") + cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20)) self.assertEqual(cache.size(), 5) # least recently used key should be gone ("key4") self.assertEqual(cache.get("key4", "cache_token"), None) # make "key5" used by writing to it - cache.put("key5", "cache_token", "val") + cache.put("key5", "cache_token", WeightedValue("val", 1 << 20)) # least recently used key should be gone ("key6") self.assertEqual(cache.get("key6", "cache_token"), None) - self.verify_metrics( - cache, - { - 'get': 10, - 'put': 12, - 'miss': 4, - 'hit': 6, - 'clear': 0, - 'evict': 0, - 'size': 5, - 'capacity': 5 - }) + self.assertEqual( + cache.describe_stats(), + 'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5') def test_is_cached_enabled(self): - cache = self.get_cache(1) + cache = StateCache(1 << 20) self.assertEqual(cache.is_cache_enabled(), True) - self.verify_metrics(cache, {}) - cache = self.get_cache(0) + self.assertEqual( + cache.describe_stats(), + 'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0') + cache = StateCache(0) self.assertEqual(cache.is_cache_enabled(), False) - 
self.verify_metrics(cache, {}) - - def verify_metrics(self, cache, expected_metrics): - infos = cache.get_monitoring_infos() - # Reconstruct metrics dictionary from monitoring infos - metrics = { - info.urn.rsplit(':', - 1)[1]: monitoring_infos.extract_gauge_value(info)[1] - for info in infos if "_total" not in info.urn and - info.type == monitoring_infos.LATEST_INT64_TYPE - } - self.assertDictEqual(metrics, expected_metrics) - # Metrics and total metrics should be identical for a single bundle. - # The following two gauges are not part of the total metrics: - try: - del metrics['capacity'] - del metrics['size'] - except KeyError: - pass - total_metrics = { - info.urn.rsplit(':', 1)[1].rsplit("_total")[0]: - monitoring_infos.extract_counter_value(info) - for info in infos - if "_total" in info.urn and info.type == monitoring_infos.SUM_INT64_TYPE - } - self.assertDictEqual(metrics, total_metrics) - - @staticmethod - def get_cache(size): - cache = StateCache(size) - cache.initialize_metrics() - return cache + self.assertEqual( + cache.describe_stats(), + 'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0') + + def test_get_referents_for_cache(self): + class GetReferentsForCache(CacheAware): + def __init__(self): + self.key = bytearray(1 << 20) + self.value = bytearray(2 << 20) + + def get_referents_for_cache(self): + return [self.key] + + cache = StateCache(5 << 20) + cache.put("key", "cache_token", GetReferentsForCache()) Review Comment: This was a little hard to follow on first reading, but I think it makes sense; if I have it wrong, maybe add a comment explaining it. CacheAware allows overriding get_referents_for_cache to return the list of objects that should be measured. It's not clear to me why this test returns only the key instead of both the key and the value. Then we're just testing that we measure the size of the referents returned by this object, right?
########## sdks/python/apache_beam/runners/portability/flink_runner_test.py: ########## @@ -296,95 +292,6 @@ def test_flattened_side_input(self): def test_metrics(self): super().test_metrics(check_gauge=False) - def test_flink_metrics(self): Review Comment: Why does this PR remove the test_flink_metrics test? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@beam.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org