Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions wsds/ws_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,27 +245,27 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard

subdirs = defaultdict(list)
exprs = []
needed_special_columns = []
needs_key = False
for query in queries:
expr = pl.sql_expr(query)
for col in expr.meta.root_names():
if col == "__key__":
if col == "__key__" or col == '__shard_path__' or col == '__shard_offset__':
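# __key__ exists in all shards; __shard_path__ and __shard_offset__ are not stored fields but are added by scan_ipc (include_file_paths / row_index_name) when a shard is scanned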
needs_key = True
needed_special_columns.append(col)
continue
subdir, field = self.fields[col]
assert col == field, "renamed fields are not supported in SQL queries yet"
subdirs[subdir].append(field)
exprs.append(expr)

# If only special columns (__key__, __shard_path__, __shard_offset__) are in the query, we still need to load shards from at least one subdir
if needs_key:
if not subdirs:
subdirs[self.fields["__key__"][0]].append('__key__')
else:
for f in subdirs.values():
f.append('__key__')
break
key_value = self.fields["__key__"]
key_subdir = key_value[0]
if needed_special_columns:
if subdirs:
key_subdir = list(subdirs.keys())[0]
subdirs[key_subdir] += needed_special_columns

if rng is None:
rng = random
Expand All @@ -284,7 +284,11 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard
for subdir, fields in subdirs.items():
shard_path = self.get_shard_path(subdir, shard)
if shard_ok:
df = scan_ipc(shard_path, glob=False).select(fields)
df = scan_ipc(
shard_path, glob=False,
include_file_paths="__shard_path__" if subdir == key_subdir else None,
row_index_name="__shard_offset__" if subdir == key_subdir else None,
).select(fields)
if subdir not in subdir_samples:
subdir_samples[subdir] = df.clear().collect()
else:
Expand Down