Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
485 changes: 79 additions & 406 deletions crates/bindings-csharp/BSATN.Runtime/QueryBuilder.cs

Large diffs are not rendered by default.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,8 @@ public readonly struct PublicTableCols
global::PublicTable,
System.Collections.Generic.List<int>
> ListField;
public readonly global::SpacetimeDB.NullableCol<
global::PublicTable,
int
> NullableValueField;
public readonly global::SpacetimeDB.NullableCol<
global::PublicTable,
string
> NullableReferenceField;
public readonly global::SpacetimeDB.Col<global::PublicTable, int> NullableValueField;
public readonly global::SpacetimeDB.Col<global::PublicTable, string> NullableReferenceField;

internal PublicTableCols(string tableName)
{
Expand Down Expand Up @@ -357,14 +351,14 @@ internal PublicTableCols(string tableName)
global::PublicTable,
System.Collections.Generic.List<int>
>(tableName, "ListField");
NullableValueField = new global::SpacetimeDB.NullableCol<global::PublicTable, int>(
NullableValueField = new global::SpacetimeDB.Col<global::PublicTable, int>(
tableName,
"NullableValueField"
);
NullableReferenceField = new global::SpacetimeDB.NullableCol<
global::PublicTable,
string
>(tableName, "NullableReferenceField");
NullableReferenceField = new global::SpacetimeDB.Col<global::PublicTable, string>(
tableName,
"NullableReferenceField"
);
}
}

Expand Down
12 changes: 4 additions & 8 deletions crates/bindings-csharp/Codegen/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,7 @@ string ColDecl(ColumnDeclaration col)
var typeName = col.Type.Name;
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
var valueTypeName = isNullable ? typeName[..^1] : typeName;
var colType = isNullable
? "global::SpacetimeDB.NullableCol"
: "global::SpacetimeDB.Col";
var colType = isNullable ? "global::SpacetimeDB.Col" : "global::SpacetimeDB.Col";
return $"public readonly {colType}<{globalRowName}, {valueTypeName}> {col.Name};";
}

Expand All @@ -915,9 +913,7 @@ string ColInit(ColumnDeclaration col)
var typeName = col.Type.Name;
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
var valueTypeName = isNullable ? typeName[..^1] : typeName;
var colType = isNullable
? "global::SpacetimeDB.NullableCol"
: "global::SpacetimeDB.Col";
var colType = isNullable ? "global::SpacetimeDB.Col" : "global::SpacetimeDB.Col";
return $"{col.Name} = new {colType}<{globalRowName}, {valueTypeName}>(tableName, \"{col.Name}\");";
}

Expand Down Expand Up @@ -950,7 +946,7 @@ string IxColDecl(ColumnDeclaration col)
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
var valueTypeName = isNullable ? typeName[..^1] : typeName;
var colType = isNullable
? "global::SpacetimeDB.NullableIxCol"
? "global::SpacetimeDB.IxCol"
: "global::SpacetimeDB.IxCol";
return $"public readonly {colType}<{globalRowName}, {valueTypeName}> {col.Name};";
}
Expand All @@ -961,7 +957,7 @@ string IxColInit(ColumnDeclaration col)
var isNullable = typeName.EndsWith("?", StringComparison.Ordinal);
var valueTypeName = isNullable ? typeName[..^1] : typeName;
var colType = isNullable
? "global::SpacetimeDB.NullableIxCol"
? "global::SpacetimeDB.IxCol"
: "global::SpacetimeDB.IxCol";
return $"{col.Name} = new {colType}<{globalRowName}, {valueTypeName}>(tableName, \"{col.Name}\");";
}
Expand Down
36 changes: 30 additions & 6 deletions crates/bindings-typescript/src/lib/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type {
TypeBuilder,
} from './type_builders';
import type { Values } from './type_util';
import type { Bool as SatsBool } from './algebraic_type_variants';

/**
* Helper to get the set of table names.
Expand Down Expand Up @@ -65,7 +66,7 @@ type From<TableDef extends TypedTableDef> = RowTypedQuery<
Readonly<{
toSql(): string;
where(
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
): From<TableDef>;
rightSemijoin<RightTable extends TypedTableDef>(
other: TableRef<RightTable>,
Expand Down Expand Up @@ -93,7 +94,7 @@ type SemijoinBuilder<TableDef extends TypedTableDef> = RowTypedQuery<
Readonly<{
toSql(): string;
where(
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
): SemijoinBuilder<TableDef>;
/** @deprecated No longer needed — builder is already a valid query. */
build(): Query<TableDef>;
Expand All @@ -120,7 +121,7 @@ class SemijoinImpl<TableDef extends TypedTableDef>
}

where(
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
): SemijoinImpl<TableDef> {
const nextSourceQuery = this.sourceQuery.where(predicate);
return new SemijoinImpl<TableDef>(
Expand Down Expand Up @@ -167,9 +168,9 @@ class FromBuilder<TableDef extends TypedTableDef>
) {}

where(
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
): FromBuilder<TableDef> {
const newCondition = predicate(this.table.cols);
const newCondition = normalizePredicateExpr(predicate(this.table.cols));
const nextWhere = this.whereClause
? this.whereClause.and(newCondition)
: newCondition;
Expand Down Expand Up @@ -308,7 +309,7 @@ class TableRefImpl<TableDef extends TypedTableDef>
}

where(
predicate: (row: RowExpr<TableDef>) => BooleanExpr<TableDef>
predicate: (row: RowExpr<TableDef>) => PredicateExpr<TableDef>
): FromBuilder<TableDef> {
return this.asFrom().where(predicate);
}
Expand Down Expand Up @@ -628,6 +629,11 @@ export type ValueExpr<TableDef extends TypedTableDef, Value> =
| LiteralExpr<Value & LiteralValue>
| ColumnExprForValue<TableDef, Value>;

type PredicateExpr<TableDef extends TypedTableDef> =
| BooleanExpr<TableDef>
| ColumnExprForValue<TableDef, SatsBool>
| boolean;

type LiteralExpr<Value> = {
type: 'literal';
value: Value;
Expand All @@ -654,6 +660,24 @@ function normalizeValue(val: ValueInput<any>): ValueExpr<any, any> {
return literal(val as LiteralValue);
}

function normalizePredicateExpr<TableDef extends TypedTableDef>(
value: PredicateExpr<TableDef>
): BooleanExpr<TableDef> {
if (value instanceof BooleanExpr) return value;
if (typeof value === 'boolean') {
return new BooleanExpr({
type: 'eq',
left: literal(value),
right: literal(true),
});
}
return new BooleanExpr({
type: 'eq',
left: value as ValueExpr<TableDef, any>,
right: literal(true),
});
}

type EqExpr<Table extends TypedTableDef = any> = BooleanExpr<Table>;

type BooleanExprData<Table extends TypedTableDef> = (
Expand Down
8 changes: 8 additions & 0 deletions crates/bindings-typescript/tests/query.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const personTable = table(
id: t.identity(),
name: t.string(),
age: t.u32(),
active: t.bool(),
}
);

Expand Down Expand Up @@ -141,6 +142,13 @@ describe('TableScan.toSql', () => {
);
});

it('accepts boolean columns directly as where predicates', () => {
const qb = makeQueryBuilder(schemaDef);
const sql = toSql(qb.person.where(row => row.active).build());

expect(sql).toBe(`SELECT * FROM "person" WHERE "person"."active" = TRUE`);
});

it('renders Identity literals using their hex form', () => {
const qb = makeQueryBuilder(schemaDef);
const identity = new Identity(
Expand Down
8 changes: 4 additions & 4 deletions crates/codegen/src/csharp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ impl Lang for Csharp<'_> {
for (field_name, field_type) in &product_type.elements {
let prop = field_name.deref().to_case(Case::Pascal);
let (col_ty, ty) = match field_type {
AlgebraicTypeUse::Option(inner) => ("NullableCol", ty_fmt(module, inner).to_string()),
AlgebraicTypeUse::Option(inner) => ("Col", ty_fmt(module, inner).to_string()),
_ => ("Col", ty_fmt(module, field_type).to_string()),
};
writeln!(
Expand All @@ -673,7 +673,7 @@ impl Lang for Csharp<'_> {
for (field_name, field_type) in &product_type.elements {
let prop = field_name.deref().to_case(Case::Pascal);
let (col_ty, ty) = match field_type {
AlgebraicTypeUse::Option(inner) => ("NullableCol", ty_fmt(module, inner).to_string()),
AlgebraicTypeUse::Option(inner) => ("Col", ty_fmt(module, inner).to_string()),
_ => ("Col", ty_fmt(module, field_type).to_string()),
};
let col_name = field_name.deref();
Expand All @@ -694,7 +694,7 @@ impl Lang for Csharp<'_> {
}
let prop = field_name.deref().to_case(Case::Pascal);
let (col_ty, ty) = match field_type {
AlgebraicTypeUse::Option(inner) => ("NullableIxCol", ty_fmt(module, inner).to_string()),
AlgebraicTypeUse::Option(inner) => ("IxCol", ty_fmt(module, inner).to_string()),
_ => ("IxCol", ty_fmt(module, field_type).to_string()),
};
writeln!(
Expand All @@ -711,7 +711,7 @@ impl Lang for Csharp<'_> {
}
let prop = field_name.deref().to_case(Case::Pascal);
let (col_ty, ty) = match field_type {
AlgebraicTypeUse::Option(inner) => ("NullableIxCol", ty_fmt(module, inner).to_string()),
AlgebraicTypeUse::Option(inner) => ("IxCol", ty_fmt(module, inner).to_string()),
_ => ("IxCol", ty_fmt(module, field_type).to_string()),
};
let col_name = field_name.deref();
Expand Down
4 changes: 2 additions & 2 deletions crates/codegen/tests/snapshots/codegen__codegen_csharp.snap
Original file line number Diff line number Diff line change
Expand Up @@ -2035,11 +2035,11 @@ namespace SpacetimeDB

public sealed class TestDCols
{
public global::SpacetimeDB.NullableCol<TestD, NamespaceTestC> TestC { get; }
public global::SpacetimeDB.Col<TestD, NamespaceTestC> TestC { get; }

public TestDCols(string tableName)
{
TestC = new global::SpacetimeDB.NullableCol<TestD, NamespaceTestC>(tableName, "test_c");
TestC = new global::SpacetimeDB.Col<TestD, NamespaceTestC>(tableName, "test_c");
}
}

Expand Down
22 changes: 22 additions & 0 deletions crates/query-builder/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ impl<T> BoolExpr<T> {
}
}

impl<T> From<Col<T, bool>> for BoolExpr<T> {
fn from(col: Col<T, bool>) -> Self {
col.eq(true)
}
}

impl<T> From<bool> for BoolExpr<T> {
fn from(value: bool) -> Self {
if value {
BoolExpr::Eq(
Operand::Literal(LiteralValue("TRUE".to_string())),
Operand::Literal(LiteralValue("TRUE".to_string())),
)
} else {
BoolExpr::Eq(
Operand::Literal(LiteralValue("FALSE".to_string())),
Operand::Literal(LiteralValue("TRUE".to_string())),
)
}
}
}

/// Trait for types that can be used as the right-hand side of a comparison with a column of type V
/// in table T.
///
Expand Down
24 changes: 14 additions & 10 deletions crates/query-builder/src/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ impl<R: HasCols, L: HasCols> Query<R> for RightSemiJoin<R, L> {

// LeftSemiJoin where() operates on L
impl<L: HasCols> LeftSemiJoin<L> {
pub fn r#where<F>(self, f: F) -> Self
pub fn r#where<F, E>(self, f: F) -> Self
where
F: Fn(&L::Cols) -> BoolExpr<L>,
F: Fn(&L::Cols) -> E,
E: Into<BoolExpr<L>>,
{
let extra = f(&L::cols(self.left_col.table_name()));
let extra = f(&L::cols(self.left_col.table_name())).into();
let new = match self.where_expr {
Some(existing) => Some(existing.and(extra)),
None => Some(extra),
Expand All @@ -159,9 +160,10 @@ impl<L: HasCols> LeftSemiJoin<L> {
}

// Filter is an alias for where
pub fn filter<F>(self, f: F) -> Self
pub fn filter<F, E>(self, f: F) -> Self
where
F: Fn(&L::Cols) -> BoolExpr<L>,
F: Fn(&L::Cols) -> E,
E: Into<BoolExpr<L>>,
{
self.r#where(f)
}
Expand Down Expand Up @@ -189,11 +191,12 @@ impl<L: HasCols> LeftSemiJoin<L> {

// RightSemiJoin where() operates on R
impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
pub fn r#where<F>(self, f: F) -> Self
pub fn r#where<F, E>(self, f: F) -> Self
where
F: Fn(&R::Cols) -> BoolExpr<R>,
F: Fn(&R::Cols) -> E,
E: Into<BoolExpr<R>>,
{
let extra = f(&R::cols(self.right_col.table_name()));
let extra = f(&R::cols(self.right_col.table_name())).into();
let new = match self.right_where_expr {
Some(existing) => Some(existing.and(extra)),
None => Some(extra),
Expand All @@ -208,9 +211,10 @@ impl<R: HasCols, L: HasCols> RightSemiJoin<R, L> {
}

// Filter is an alias for where
pub fn filter<F>(self, f: F) -> Self
pub fn filter<F, E>(self, f: F) -> Self
where
F: Fn(&R::Cols) -> BoolExpr<R>,
F: Fn(&R::Cols) -> E,
E: Into<BoolExpr<R>>,
{
self.r#where(f)
}
Expand Down
Loading
Loading