Skip to content

Commit

Permalink
[CHIA-387] DL batch upsert optimization. (#17999)
Browse files Browse the repository at this point in the history
* DL batch upsert optimization.

* Lint

* Lint

* Fix test.

* Convert delete/insert to upserts.

* Update data_store.py

* Improve coverage.

* Whitespace.

* Change test to use upsert too.

* Clarify test usage.
  • Loading branch information
fchirica committed May 16, 2024
1 parent d868d9c commit b7e8fe7
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 18 deletions.
37 changes: 29 additions & 8 deletions chia/_tests/core/data_layer/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from chia._tests.core.data_layer.util import Example, add_0123_example, add_01234567_example
from chia._tests.util.misc import BenchmarkRunner, Marks, datacases
from chia._tests.util.misc import BenchmarkRunner, Marks, boolean_datacases, datacases
from chia.data_layer.data_layer_errors import KeyNotFoundError, NodeHashError, TreeGenerationIncrementingError
from chia.data_layer.data_layer_util import (
DiffData,
Expand Down Expand Up @@ -1991,20 +1991,41 @@ async def test_insert_key_already_present(data_store: DataStore, store_id: bytes


@pytest.mark.anyio
@boolean_datacases(name="use_batch_autoinsert", false="not optimized batch insert", true="optimized batch insert")
async def test_batch_insert_key_already_present(
    data_store: DataStore,
    store_id: bytes32,
    use_batch_autoinsert: bool,
) -> None:
    """Inserting the same key twice via insert_batch must fail on the second batch.

    Exercised for both the optimized (batch autoinsert) and the non-optimized
    code paths, since the duplicate-key detection differs between them.
    """
    key = b"foo"
    value = b"bar"
    changelist = [{"action": "insert", "key": key, "value": value}]
    # First batch commits the key successfully.
    await data_store.insert_batch(store_id, changelist, Status.COMMITTED, use_batch_autoinsert)
    # Re-inserting the identical key must raise with the key's hex in the message.
    with pytest.raises(Exception, match=f"Key already present: {key.hex()}"):
        await data_store.insert_batch(store_id, changelist, Status.COMMITTED, use_batch_autoinsert)


@pytest.mark.anyio
@boolean_datacases(name="use_upsert", false="update with delete and insert", true="update with upsert")
async def test_update_keys(data_store: DataStore, store_id: bytes32, use_upsert: bool) -> None:
num_keys = 10
missing_keys = 50
num_values = 10
new_keys = 10
for value in range(num_values):
changelist: List[Dict[str, Any]] = []
bytes_value = value.to_bytes(4, byteorder="big")
for key in range(num_keys + missing_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "delete", "key": bytes_key})
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})
if use_upsert:
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "upsert", "key": bytes_key, "value": bytes_value})
else:
for key in range(num_keys + missing_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "delete", "key": bytes_key})
for key in range(num_keys):
bytes_key = key.to_bytes(4, byteorder="big")
changelist.append({"action": "insert", "key": bytes_key, "value": bytes_value})

await data_store.insert_batch(
store_id=store_id,
Expand Down
94 changes: 84 additions & 10 deletions chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,9 @@ async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None)
else:
await writer.execute(query, params)

async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
async def get_leaf_at_minimum_height(
self, root_hash: bytes32, hash_to_parent: Dict[bytes32, InternalNode]
) -> TerminalNode:
root_node = await self.get_node(root_hash)
queue: List[Node] = [root_node]
while True:
Expand All @@ -1378,11 +1380,29 @@ async def get_leaf_at_minimum_height(self, root_hash: bytes32) -> TerminalNode:
if isinstance(node, InternalNode):
left_node = await self.get_node(node.left_hash)
right_node = await self.get_node(node.right_hash)
hash_to_parent[left_node.hash] = node
hash_to_parent[right_node.hash] = node
queue.append(left_node)
queue.append(right_node)
elif isinstance(node, TerminalNode):
return node

async def batch_upsert(
    self,
    tree_id: bytes32,
    hash: bytes32,
    to_update_hashes: Set[bytes32],
    pending_upsert_new_hashes: Dict[bytes32, bytes32],
) -> bytes32:
    """Recursively rebuild the subtree rooted at ``hash``, applying pending upserts.

    Subtrees containing no updated leaf are returned unchanged; an updated leaf
    is replaced by its pre-inserted new terminal node; internal nodes on an
    update path are re-created from their rebuilt children. Returns the hash of
    the (possibly new) subtree root.
    """
    # Fast path: nothing under this node changed, so the subtree is reused as-is.
    if hash not in to_update_hashes:
        return hash
    current = await self.get_node(hash)
    # A leaf on the update path is swapped for its replacement terminal node.
    if isinstance(current, TerminalNode):
        return pending_upsert_new_hashes[hash]
    # Rebuild both children first, then materialize the new internal node.
    rebuilt_left = await self.batch_upsert(tree_id, current.left_hash, to_update_hashes, pending_upsert_new_hashes)
    rebuilt_right = await self.batch_upsert(tree_id, current.right_hash, to_update_hashes, pending_upsert_new_hashes)
    return await self._insert_internal_node(rebuilt_left, rebuilt_right)

async def insert_batch(
self,
store_id: bytes32,
Expand Down Expand Up @@ -1410,14 +1430,19 @@ async def insert_batch(

key_hash_frequency: Dict[bytes32, int] = {}
first_action: Dict[bytes32, str] = {}
last_action: Dict[bytes32, str] = {}

for change in changelist:
key = change["key"]
hash = key_hash(key)
key_hash_frequency[hash] = key_hash_frequency.get(hash, 0) + 1
if hash not in first_action:
first_action[hash] = change["action"]
last_action[hash] = change["action"]

pending_autoinsert_hashes: List[bytes32] = []
pending_upsert_new_hashes: Dict[bytes32, bytes32] = {}

for change in changelist:
if change["action"] == "insert":
key = change["key"]
Expand All @@ -1435,8 +1460,16 @@ async def insert_batch(
if key_hash_frequency[hash] == 1 or (
key_hash_frequency[hash] == 2 and first_action[hash] == "delete"
):
old_node = await self.maybe_get_node_by_key(key, store_id)
terminal_node_hash = await self._insert_terminal_node(key, value)
pending_autoinsert_hashes.append(terminal_node_hash)

if old_node is None:
pending_autoinsert_hashes.append(terminal_node_hash)
else:
if key_hash_frequency[hash] == 1:
raise Exception(f"Key already present: {key.hex()}")
else:
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
continue
insert_result = await self.autoinsert(
key, value, store_id, True, Status.COMMITTED, root=latest_local_root
Expand All @@ -1458,17 +1491,50 @@ async def insert_batch(
latest_local_root = insert_result.root
elif change["action"] == "delete":
key = change["key"]
hash = key_hash(key)
if key_hash_frequency[hash] == 2 and last_action[hash] == "insert" and enable_batch_autoinsert:
continue
latest_local_root = await self.delete(key, store_id, True, Status.COMMITTED, root=latest_local_root)
elif change["action"] == "upsert":
key = change["key"]
new_value = change["value"]
hash = key_hash(key)
if key_hash_frequency[hash] == 1 and enable_batch_autoinsert:
terminal_node_hash = await self._insert_terminal_node(key, new_value)
old_node = await self.maybe_get_node_by_key(key, store_id)
if old_node is not None:
pending_upsert_new_hashes[old_node.hash] = terminal_node_hash
else:
pending_autoinsert_hashes.append(terminal_node_hash)
continue
insert_result = await self.upsert(
key, new_value, store_id, True, Status.COMMITTED, root=latest_local_root
)
latest_local_root = insert_result.root
else:
raise Exception(f"Operation in batch is not insert or delete: {change}")

if len(pending_upsert_new_hashes) > 0:
to_update_hashes: Set[bytes32] = set()
for hash in pending_upsert_new_hashes.keys():
while True:
if hash in to_update_hashes:
break
to_update_hashes.add(hash)
node = await self._get_one_ancestor(hash, store_id)
if node is None:
break
hash = node.hash
assert latest_local_root is not None
assert latest_local_root.node_hash is not None
new_root_hash = await self.batch_upsert(
store_id,
latest_local_root.node_hash,
to_update_hashes,
pending_upsert_new_hashes,
)
latest_local_root = await self._insert_root(store_id, new_root_hash, Status.COMMITTED)

# Start with the leaf nodes and pair them to form new nodes at the next level up, repeating this process
# in a bottom-up fashion until a single root node remains. This constructs a balanced tree from the leaves.
while len(pending_autoinsert_hashes) > 1:
Expand All @@ -1488,14 +1554,15 @@ async def insert_batch(
if latest_local_root is None or latest_local_root.node_hash is None:
await self._insert_root(store_id=store_id, node_hash=subtree_hash, status=Status.COMMITTED)
else:
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash)
ancestors = await self.get_ancestors_common(
node_hash=min_height_leaf.hash,
store_id=store_id,
root_hash=latest_local_root.node_hash,
generation=latest_local_root.generation,
use_optimized=True,
)
hash_to_parent: Dict[bytes32, InternalNode] = {}
min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash, hash_to_parent)
ancestors: List[InternalNode] = []
hash = min_height_leaf.hash
while hash in hash_to_parent:
node = hash_to_parent[hash]
ancestors.append(node)
hash = node.hash

await self.update_ancestor_hashes_on_insert(
store_id=store_id,
left=min_height_leaf.hash,
Expand Down Expand Up @@ -1631,6 +1698,13 @@ async def get_node_by_key_latest_generation(self, key: bytes, store_id: bytes32)
assert isinstance(node, TerminalNode)
return node

async def maybe_get_node_by_key(self, key: bytes, tree_id: bytes32) -> Optional[TerminalNode]:
    """Return the terminal node for ``key`` in the latest generation, or None if absent.

    Non-raising convenience wrapper around ``get_node_by_key_latest_generation``.
    """
    try:
        return await self.get_node_by_key_latest_generation(key, tree_id)
    except KeyNotFoundError:
        # Missing key is an expected outcome for callers probing for presence.
        return None

async def get_node_by_key(
self,
key: bytes,
Expand Down

0 comments on commit b7e8fe7

Please sign in to comment.