From 8f000c70710eabcc3987c6549f9718601afc0852 Mon Sep 17 00:00:00 2001 From: LunaStev Date: Sun, 8 Mar 2026 19:19:36 +0900 Subject: [PATCH] feat: implement generics, explicit type casting, and static globals This commit introduces the foundational infrastructure for generic programming, explicit type casting via the `as` operator, and global static variables. It also significantly matures the target-specific preprocessing and backend control. Changes: - **Generics Support**: - **AST & Parser**: Added `generic_params` to `FunctionNode` and `StructNode`. Added `type_args` to function calls. - **Parsing**: Implemented `parse_generic_param_names` and enhanced the type parser to handle nested generic arguments and top-level split logic. - **Stdlib**: Refactored core modules to provide generic utilities, including `TypedBuffer`, `ptr_swap`, and generic math functions (`num_abs`, `num_min`, etc.). - **Monomorphization**: Integrated a monomorphization pass into the compilation runner. - **Language Enhancements**: - **Type Casting**: Introduced the `as` keyword and `Expression::Cast` to support explicit type conversions in both runtime and compile-time contexts. - **Statics**: Added the `static` keyword for global variables, including backend support for global LLVM symbols. - **Null Support**: Added a native `null` literal and ensured type compatibility with pointers. - **Backend & Cross-Platform**: - **Target Control**: Expanded the `BackendOptions` to allow overriding target triples, CPU architectures, and ABIs. - **ABI Refinement**: Updated `abi_c.rs` to support both x86_64 and Arm64 for Linux and Darwin (macOS), including improved aggregate splitting logic for ARM64. - **Robust Preprocessing**: Overhauled `#[target(os="...")]` preprocessing to correctly ignore braces and semicolons inside comments or string literals. - **CLI & Documentation**: - Updated `README.md` to document the new **Tiered Platform Policy** and added a section for building from source. - Expanded the CLI with advanced `--llvm` and linker options. - Added a reference to the `doom.wave` example. These features significantly increase the expressiveness of the Wave language and provide the tools necessary for writing reusable, platform-aware systems code. Signed-off-by: LunaStev --- examples/observability_gateway.wave | 394 +++++++++ front/parser/src/ast.rs | 3 + front/parser/src/expr/assign.rs | 4 +- front/parser/src/expr/binary.rs | 22 +- front/parser/src/expr/postfix.rs | 2 +- front/parser/src/expr/primary.rs | 101 ++- front/parser/src/expr/unary.rs | 2 +- front/parser/src/generics.rs | 812 +++++++++++++++++++ front/parser/src/lib.rs | 1 + front/parser/src/parser/decl.rs | 9 + front/parser/src/parser/expr.rs | 6 +- front/parser/src/parser/functions.rs | 93 ++- front/parser/src/parser/items.rs | 5 +- front/parser/src/parser/types.rs | 72 +- front/parser/tests/parse_var_and_generics.rs | 50 ++ llvm/src/expression/rvalue/calls.rs | 8 + llvm/src/expression/rvalue/dispatch.rs | 8 +- src/runner.rs | 38 + std/buffer/alloc.wave | 96 +++ std/buffer/read.wave | 20 + std/buffer/types.wave | 7 + std/buffer/write.wave | 29 + std/env/environ.wave | 75 +- std/env/parse.wave | 35 + std/math/float.wave | 30 +- std/math/int.wave | 38 +- std/math/trig.wave | 8 +- std/mem/alloc.wave | 79 ++ std/mem/cstr.wave | 54 ++ std/mem/ops.wave | 86 ++ std/net/address.wave | 74 ++ std/net/socket_base.wave | 133 +++ std/net/tcp.wave | 194 +++-- std/net/udp.wave | 170 ++-- test/test91.wave | 77 ++ 35 files changed, 2548 insertions(+), 287 deletions(-) create mode 100644 examples/observability_gateway.wave create mode 100644 front/parser/src/generics.rs create mode 100644 front/parser/tests/parse_var_and_generics.rs create mode 100644 std/net/address.wave create mode 100644 std/net/socket_base.wave create mode 100644 test/test91.wave diff --git a/examples/observability_gateway.wave b/examples/observability_gateway.wave new file mode 100644 index 00000000..54c517a9 --- /dev/null +++ b/examples/observability_gateway.wave @@ -0,0 +1,394 @@ +import("std::env::cwd"); +import("std::env::environ"); +import("std::path::copy"); +import("std::string::len"); +import("std::string::trim"); +import("std::string::ascii"); +import("std::string::hash"); +import("std::math::int"); +import("std::math::bits"); +import("std::math::float"); +import("std::math::num"); +import("std::time::clock"); +import("std::time::diff"); +import("std::time::sleep"); +import("std::buffer::alloc"); +import("std::buffer::write"); +import("std::buffer::read"); +import("std::mem::ops"); +import("std::net::tcp"); + +const DEFAULT_PORT: i32 = 18080; +const ALERT_THRESHOLD: i32 = 3; +const MAX_ROUTE_SCORE: i32 = 100; + +static GLOBAL_EVENT_SEQ: i64 = 0; + +type TenantId = i64; + +enum Severity -> i32 { + INFO = 1, + WARN = 2, + ERROR = 3 +} + +#[target(os="linux")] +fun default_data_root() -> str { + return "/var/lib/wave"; +} + +#[target(os="macos")] +fun default_data_root() -> str { + return "/usr/local/var/wave"; +} + +struct KeyValue { + key: K; + value: V; +} + +struct Window { + at_ns: i64; + data: T; +} + +struct ApiResult { + ok: bool; + data: T; +} + +struct RequestCtx { + tenant: TenantId; + method: str; + route: str; +} + +struct Metrics { + total: i32; + info: i32; + warn: i32; + error: i32; +} + +struct IngestRecord { + seq: i64; + ctx: RequestCtx; + severity: Severity; + message: str; + took_ns: i64; + route_score: i32; + route_hash: i64; +} + +proto Metrics { + fun record(self: Metrics, sev: Severity) -> Metrics { + var next: Metrics = self; + next.total += 1; + + if (sev == INFO) { + next.info += 1; + } else if (sev == WARN) { + next.warn += 1; + } else { + next.error += 1; + } + + return next; + } + + fun is_alert(self: Metrics, threshold: i32) -> bool { + if (self.error >= threshold) { + return true; + } + return false; + } +} + +proto RequestCtx { + fun is_api(self: RequestCtx) -> bool { + return is_api_route(self.route); + } +} + +fun identity(x: T) -> T { + return x; +} + +fun choose(left: T, right: T, pick_left: bool) -> T { + if (pick_left) { + return left; + } + return right; +} + +fun make_kv(k: K, v: V) -> KeyValue { + var pair_value: KeyValue; + pair_value.key = k; + pair_value.value = v; + return pair_value; +} + +fun wrap_window(data: T, at_ns: i64) -> Window { + var w: Window; + w.at_ns = at_ns; + w.data = data; + return w; +} + +fun ok(payload: T) -> ApiResult { + var r: ApiResult; + r.ok = true; + r.data = payload; + return r; +} + +fun is_api_route(route: str) -> bool { + var n: i32 = len(route); + if (n < 5) { + return false; + } + + if (route[0] != 47) { + return false; + } + + if (to_lower(route[1]) != 97) { + return false; + } + + if (to_lower(route[2]) != 112) { + return false; + } + + if (to_lower(route[3]) != 105) { + return false; + } + + if (route[4] != 47) { + return false; + } + + return true; +} + +fun parse_severity(line: str) -> Severity { + var i: i32 = trim_left_index(line); + var c: u8 = to_upper(line[i]); + + if (c == 69) { + return ERROR; + } + + if (c == 87) { + return WARN; + } + + return INFO; +} + +fun score_route(route: str) -> i32 { + var n: i32 = len(route); + var score: i32 = clamp(n * 2, 0, MAX_ROUTE_SCORE); + + if (path_has_ext(route)) { + score += 12; + } + + if (!path_is_abs(route)) { + score += 4; + } + + var ext_start: i32 = path_ext_start(route); + if (ext_start >= 0) { + var tail: i32 = n - ext_start; + score += min(tail * 3, 20); + } + + score += popcount(n); + score = align_up(score, 4); + score = clamp(score, 0, MAX_ROUTE_SCORE); + + return score; +} + +fun build_log_frame(ctx: RequestCtx, sev: Severity) -> i64 { + var frame: Buffer = buffer_new_default(); + + var tenant_kv: KeyValue = make_kv("tenant", ctx.tenant); + + buffer_append_str(&frame, tenant_kv.key); + buffer_push(&frame, 61); + + if (sev == ERROR) { + buffer_append_str(&frame, "ERROR"); + } else if (sev == WARN) { + buffer_append_str(&frame, "WARN"); + } else { + buffer_append_str(&frame, "INFO"); + } + + buffer_push(&frame, 124); + buffer_append_str(&frame, ctx.route); + + var first_byte: u8 = buffer_at(frame, 0); + + var a: array; + var b: array; + + mem_set(&a[0], first_byte, 4); + mem_copy(&b[0], &a[0], 4); + + var ok_mem: bool = (mem_cmp(&a[0], &b[0], 4) == 0); + var used: i64 = frame.len; + + if (!ok_mem) { + used = -1; + } + + buffer_free(&frame); + return used; +} + +fun ingest_one( + ctx: RequestCtx, + line: str, + base_metrics: Metrics +) -> ApiResult> { + var start_ts: TimeSpec; + time_now_monotonic(&start_ts); + + GLOBAL_EVENT_SEQ += 1; + + var sev: Severity = parse_severity(line); + var next_metrics: Metrics = base_metrics.record(sev); + + var end_ts: TimeSpec; + time_now_monotonic(&end_ts); + + var took_ns: i64 = time_diff_ns(start_ts, end_ts); + + var rec: IngestRecord; + rec.seq = GLOBAL_EVENT_SEQ; + rec.ctx = ctx; + rec.severity = sev; + rec.message = line; + rec.took_ns = took_ns; + rec.route_score = score_route(ctx.route); + rec.route_hash = fnv1a_64(ctx.route); + + var wrapped: Window = wrap_window( + rec, + time_now_monotonic_ns() + ); + + if (next_metrics.is_alert(ALERT_THRESHOLD)) { + println("ALERT tenant={} errors={}", ctx.tenant, next_metrics.error); + } + + return ok>(wrapped); +} + +fun bootstrap_runtime() { + var cwd_buf: array; + var cwd_len: i64 = env_getcwd(&cwd_buf[0], 512); + + var env_port_raw: array; + var env_port_len: i64 = env_get("WAVE_PORT", &env_port_raw[0], 64); + + var configured_port: i32 = env_get_i32_default("WAVE_PORT", DEFAULT_PORT); + var port: i32 = clamp(configured_port, 1024, 65535); + + var listener: TcpListener = tcp_bind(port as i16); + tcp_close_listener(listener); + + var root: str = default_data_root(); + + var log_path_buf: array; + var base_buf: array; + var dir_buf: array; + + var log_path_len: i32 = path_join2(&log_path_buf[0], 256, root, "spool/events.log"); + var base_len: i32 = path_basename_copy(&base_buf[0], 128, root); + var dir_len: i32 = path_dirname_copy(&dir_buf[0], 128, root); + + println("platform root: {}", root); + println("cwd length: {}", cwd_len); + println("env WAVE_PORT len: {}", env_port_len); + println("log path bytes: {}", log_path_len); + println("base len: {} dirname len: {}", base_len, dir_len); + println("has HOME? {}", env_exists("HOME")); +} + +fun main() { + bootstrap_runtime(); + + var metrics: Metrics = Metrics { + total: 0, + info: 0, + warn: 0, + error: 0 + }; + + var ctx: RequestCtx = RequestCtx { + tenant: 42, + method: "POST", + route: "/api/v1/orders/create.csv" + }; + + if (ctx.is_api()) { + println("api route accepted: {}", ctx.route); + } + + var r1: ApiResult> = ingest_one( + ctx, + "WARN order latency reached 240ms", + metrics + ); + + metrics = metrics.record(r1.data.data.severity); + + var r2: ApiResult> = ingest_one( + ctx, + "ERROR payment provider timeout", + metrics + ); + + metrics = metrics.record(r2.data.data.severity); + + var normalized_score: i32 = identity(r2.data.data.route_score); + var stronger_score: i32 = choose(normalized_score, 64, normalized_score > 64); + + var jitter_raw: i32 = abs((r2.data.data.took_ns % 97) as i32); + var jitter_aligned: i32 = align_down(jitter_raw + 7, 4); + var jitter_factor: f32 = clamp_f32((jitter_aligned as f32) / 10.0, 0.1, 10.0); + + var frame_size: i64 = build_log_frame(ctx, r2.data.data.severity); + var shard: i32 = ilog2_ceil(max(1, stronger_score)); + var bucket: i32 = gcd((r2.data.data.route_hash % 1000) as i32, 360); + + println( + "seq={} score={} frame={} took={}ns", + r2.data.data.seq, + stronger_score, + frame_size, + r2.data.data.took_ns + ); + + println( + "hash={} shard={} bucket={} jitter={}", + r2.data.data.route_hash, + shard, + bucket, + jitter_factor + ); + + println( + "metrics total={} info={} warn={} error={}", + metrics.total, + metrics.info, + metrics.warn, + metrics.error + ); + + time_sleep_ms(1); +} diff --git a/front/parser/src/ast.rs b/front/parser/src/ast.rs index eb300bae..2560ffaa 100644 --- a/front/parser/src/ast.rs +++ b/front/parser/src/ast.rs @@ -69,6 +69,7 @@ pub struct EnumVariantNode { #[derive(Debug, Clone)] pub struct FunctionNode { pub name: String, + pub generic_params: Vec, pub parameters: Vec, pub return_type: Option, pub body: Vec, @@ -77,6 +78,7 @@ pub struct FunctionNode { #[derive(Debug, Clone)] pub struct StructNode { pub name: String, + pub generic_params: Vec, pub fields: Vec<(String, WaveType)>, pub methods: Vec, } @@ -132,6 +134,7 @@ pub enum Expression { }, FunctionCall { name: String, + type_args: Vec, args: Vec, }, MethodCall { diff --git a/front/parser/src/expr/assign.rs b/front/parser/src/expr/assign.rs index 3ae84a7b..f8b89e95 100644 --- a/front/parser/src/expr/assign.rs +++ b/front/parser/src/expr/assign.rs @@ -16,14 +16,14 @@ use lexer::Token; pub fn parse_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { parse_assignment_expression(tokens) } pub fn parse_assignment_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let left = parse_logical_or_expression(tokens)?; diff --git a/front/parser/src/expr/binary.rs b/front/parser/src/expr/binary.rs index 9784828d..67765207 100644 --- a/front/parser/src/expr/binary.rs +++ b/front/parser/src/expr/binary.rs @@ -17,7 +17,7 @@ use lexer::Token; pub fn parse_logical_or_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_logical_and_expression(tokens)?; @@ -41,7 +41,7 @@ pub fn parse_logical_and_expression<'a, T>( tokens: &mut std::iter::Peekable, ) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_bitwise_or_expression(tokens)?; @@ -63,7 +63,7 @@ where pub fn parse_bitwise_or_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_bitwise_xor_expression(tokens)?; @@ -87,7 +87,7 @@ pub fn parse_bitwise_xor_expression<'a, T>( tokens: &mut std::iter::Peekable, ) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_bitwise_and_expression(tokens)?; @@ -108,7 +108,7 @@ pub fn parse_bitwise_and_expression<'a, T>( tokens: &mut std::iter::Peekable, ) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_equality_expression(tokens)?; @@ -130,7 +130,7 @@ where pub fn parse_equality_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_relational_expression(tokens)?; @@ -154,7 +154,7 @@ where pub fn parse_relational_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_shift_expression(tokens)?; @@ -180,7 +180,7 @@ where pub fn parse_shift_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_additive_expression(tokens)?; @@ -205,7 +205,7 @@ where pub fn parse_additive_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_multiplicative_expression(tokens)?; @@ -231,7 +231,7 @@ pub fn parse_multiplicative_expression<'a, T>( tokens: &mut std::iter::Peekable, ) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut left = parse_cast_expression(tokens)?; @@ -256,7 +256,7 @@ where fn parse_cast_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let mut expr = parse_unary_expression(tokens)?; diff --git a/front/parser/src/expr/postfix.rs b/front/parser/src/expr/postfix.rs index 2733e7cd..9d1322bf 100644 --- a/front/parser/src/expr/postfix.rs +++ b/front/parser/src/expr/postfix.rs @@ -22,7 +22,7 @@ pub fn parse_postfix_expression<'a, T>( mut expr: Expression, ) -> Option where - T: Iterator, + T: Iterator + Clone, { loop { match tokens.peek().map(|t| &t.token_type) { diff --git a/front/parser/src/expr/primary.rs b/front/parser/src/expr/primary.rs index 6526d552..3578dd88 100644 --- a/front/parser/src/expr/primary.rs +++ b/front/parser/src/expr/primary.rs @@ -16,12 +16,47 @@ use lexer::Token; use crate::asm::{parse_asm_clobber_clause, parse_asm_inout_clause}; use crate::ast::{Expression, Literal}; +use crate::decl::collect_generic_inner; use crate::expr::parse_expression; use crate::expr::postfix::parse_postfix_expression; +use crate::types::{parse_type, split_top_level_generic_args, token_type_to_wave_type}; + +fn skip_ws<'a, T>(tokens: &mut Peekable) +where + T: Iterator + Clone, +{ + while matches!( + tokens.peek().map(|t| &t.token_type), + Some(TokenType::Whitespace | TokenType::Newline) + ) { + tokens.next(); + } +} + +fn peek_is_generic_call<'a, T>(tokens: &Peekable) -> bool +where + T: Iterator + Clone, +{ + let mut probe = tokens.clone(); + if !matches!(probe.peek().map(|t| &t.token_type), Some(TokenType::Lchevr)) { + return false; + } + probe.next(); // '<' + if collect_generic_inner(&mut probe).is_none() { + return false; + } + while matches!( + probe.peek().map(|t| &t.token_type), + Some(TokenType::Whitespace | TokenType::Newline) + ) { + probe.next(); + } + matches!(probe.peek().map(|t| &t.token_type), Some(TokenType::Lparen)) +} pub fn parse_primary_expression<'a, T>(tokens: &mut Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { let token = (*tokens.peek()?).clone(); @@ -52,6 +87,64 @@ where let expr = if let Some(peeked_token) = tokens.peek() { match &peeked_token.token_type { + TokenType::Lchevr if peek_is_generic_call(tokens) => { + tokens.next(); // consume '<' + let inner = collect_generic_inner(tokens)?; + let arg_strs = split_top_level_generic_args(&inner)?; + + let mut type_args = Vec::with_capacity(arg_strs.len()); + for arg in arg_strs { + let tt = parse_type(&arg)?; + let wt = token_type_to_wave_type(&tt)?; + type_args.push(wt); + } + + skip_ws(tokens); + if tokens + .peek() + .map_or(true, |t| t.token_type != TokenType::Lparen) + { + println!("Error: Expected '(' after generic function type arguments"); + return None; + } + tokens.next(); // consume '(' + + let mut args = vec![]; + if tokens + .peek() + .map_or(false, |t| t.token_type != TokenType::Rparen) + { + loop { + let arg = parse_expression(tokens)?; + args.push(arg); + + if let Some(Token { + token_type: TokenType::Comma, + .. + }) = tokens.peek() + { + tokens.next(); + } else { + break; + } + } + } + + if tokens + .peek() + .map_or(true, |t| t.token_type != TokenType::Rparen) + { + println!("Error: Expected ')' after function call arguments"); + return None; + } + tokens.next(); + + Expression::FunctionCall { + name, + type_args, + args, + } + } TokenType::Lparen => { tokens.next(); @@ -85,7 +178,11 @@ where } tokens.next(); - Expression::FunctionCall { name, args } + Expression::FunctionCall { + name, + type_args: Vec::new(), + args, + } } TokenType::Lbrace => { tokens.next(); diff --git a/front/parser/src/expr/unary.rs b/front/parser/src/expr/unary.rs index 963d09ef..b7e47ab5 100644 --- a/front/parser/src/expr/unary.rs +++ b/front/parser/src/expr/unary.rs @@ -17,7 +17,7 @@ use lexer::Token; pub fn parse_unary_expression<'a, T>(tokens: &mut std::iter::Peekable) -> Option where - T: Iterator, + T: Iterator + Clone, { if let Some(token) = tokens.peek() { match token.token_type { diff --git a/front/parser/src/generics.rs b/front/parser/src/generics.rs new file mode 100644 index 00000000..bdb038c8 --- /dev/null +++ b/front/parser/src/generics.rs @@ -0,0 +1,812 @@ +// This file is part of the Wave language project. +// Copyright (c) 2024–2026 Wave Foundation +// Copyright (c) 2024–2026 LunaStev and contributors +// +// This Source Code Form is subject to the terms of the +// Mozilla Public License, v. 2.0. +// If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +use crate::ast::{ + ASTNode, EnumNode, Expression, ExternFunctionNode, FunctionNode, MatchArm, MatchPattern, + ParameterNode, ProtoImplNode, StatementNode, StructNode, TypeAliasNode, VariableNode, WaveType, +}; +use crate::types::{parse_type, split_top_level_generic_args, token_type_to_wave_type}; +use std::collections::{BTreeMap, HashMap, HashSet}; + +#[derive(Default)] +struct GenericEnv { + function_templates: HashMap, + struct_templates: HashMap, + + function_instances: BTreeMap, + struct_instances: BTreeMap, + + function_in_progress: HashSet, + struct_in_progress: HashSet, +} + +pub fn monomorphize_generics(ast: Vec) -> Result, String> { + let mut env = GenericEnv::default(); + + for node in &ast { + match node { + ASTNode::Function(f) if !f.generic_params.is_empty() => { + if env + .function_templates + .insert(f.name.clone(), f.clone()) + .is_some() + { + return Err(format!("duplicate generic function template '{}'", f.name)); + } + } + ASTNode::Struct(s) if !s.generic_params.is_empty() => { + if env + .struct_templates + .insert(s.name.clone(), s.clone()) + .is_some() + { + return Err(format!("duplicate generic struct template '{}'", s.name)); + } + } + _ => {} + } + } + + let mut out: Vec = Vec::new(); + let empty_subst: HashMap = HashMap::new(); + + for node in ast { + match node { + ASTNode::Function(f) => { + if f.generic_params.is_empty() { + out.push(ASTNode::Function(rewrite_function( + f, + &empty_subst, + &mut env, + )?)); + } + } + ASTNode::Struct(s) => { + if s.generic_params.is_empty() { + out.push(ASTNode::Struct(rewrite_struct(s, &empty_subst, &mut env)?)); + } + } + ASTNode::Variable(v) => { + out.push(ASTNode::Variable(rewrite_variable( + v, + &empty_subst, + &mut env, + )?)); + } + ASTNode::ExternFunction(e) => { + out.push(ASTNode::ExternFunction(rewrite_extern( + e, + &empty_subst, + &mut env, + )?)); + } + ASTNode::ProtoImpl(p) => { + out.push(ASTNode::ProtoImpl(rewrite_proto( + p, + &empty_subst, + &mut env, + )?)); + } + ASTNode::TypeAlias(TypeAliasNode { name, target }) => { + out.push(ASTNode::TypeAlias(TypeAliasNode { + name, + target: rewrite_wave_type(&target, &empty_subst, &mut env)?, + })); + } + ASTNode::Enum(EnumNode { + name, + repr_type, + variants, + }) => { + out.push(ASTNode::Enum(EnumNode { + name, + repr_type: rewrite_wave_type(&repr_type, &empty_subst, &mut env)?, + variants, + })); + } + ASTNode::Statement(stmt) => { + out.push(ASTNode::Statement(rewrite_statement( + stmt, + &empty_subst, + &mut env, + )?)); + } + ASTNode::Expression(expr) => { + out.push(ASTNode::Expression(rewrite_expression( + expr, + &empty_subst, + &mut env, + )?)); + } + ASTNode::Program(p) => out.push(ASTNode::Program(p)), + } + } + + for (_, s) in env.struct_instances { + out.push(ASTNode::Struct(s)); + } + for (_, f) in env.function_instances { + out.push(ASTNode::Function(f)); + } + + Ok(out) +} + +fn rewrite_parameter( + param: ParameterNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + Ok(ParameterNode { + name: param.name, + param_type: rewrite_wave_type(¶m.param_type, subst, env)?, + initial_value: param.initial_value, + }) +} + +fn rewrite_function( + mut f: FunctionNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + if !f.generic_params.is_empty() { + return Err(format!( + "internal error: unresolved generic params in function '{}': {:?}", + f.name, f.generic_params + )); + } + + f.parameters = f + .parameters + .into_iter() + .map(|p| rewrite_parameter(p, subst, env)) + .collect::, _>>()?; + + f.return_type = f + .return_type + .as_ref() + .map(|t| rewrite_wave_type(t, subst, env)) + .transpose()?; + + f.body = f + .body + .into_iter() + .map(|n| rewrite_node(n, subst, env)) + .collect::, _>>()?; + + Ok(f) +} + +fn rewrite_struct( + mut s: StructNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + if !s.generic_params.is_empty() { + return Err(format!( + "internal error: unresolved generic params in struct '{}': {:?}", + s.name, s.generic_params + )); + } + + s.fields = s + .fields + .into_iter() + .map(|(n, t)| Ok((n, rewrite_wave_type(&t, subst, env)?))) + .collect::, String>>()?; + + s.methods = s + .methods + .into_iter() + .map(|m| { + if !m.generic_params.is_empty() { + return Err(format!( + "generic methods are not supported yet: '{}::{}'", + s.name, m.name + )); + } + rewrite_function(m, subst, env) + }) + .collect::, _>>()?; + + Ok(s) +} + +fn rewrite_proto( + mut p: ProtoImplNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + p.methods = p + .methods + .into_iter() + .map(|m| { + if !m.generic_params.is_empty() { + return Err(format!( + "generic methods are not supported yet: 'proto {}::{}'", + p.target, m.name + )); + } + rewrite_function(m, subst, env) + }) + .collect::, _>>()?; + Ok(p) +} + +fn rewrite_extern( + mut e: ExternFunctionNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + e.params = e + .params + .into_iter() + .map(|(n, t)| Ok((n, rewrite_wave_type(&t, subst, env)?))) + .collect::, String>>()?; + e.return_type = rewrite_wave_type(&e.return_type, subst, env)?; + Ok(e) +} + +fn rewrite_variable( + mut v: VariableNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + v.type_name = rewrite_wave_type(&v.type_name, subst, env)?; + v.initial_value = v + .initial_value + .as_ref() + .map(|e| rewrite_expression(e.clone(), subst, env)) + .transpose()?; + Ok(v) +} + +fn rewrite_node( + node: ASTNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + match node { + ASTNode::Variable(v) => Ok(ASTNode::Variable(rewrite_variable(v, subst, env)?)), + ASTNode::Statement(s) => Ok(ASTNode::Statement(rewrite_statement(s, subst, env)?)), + ASTNode::Expression(e) => Ok(ASTNode::Expression(rewrite_expression(e, subst, env)?)), + ASTNode::Function(f) => Ok(ASTNode::Function(rewrite_function(f, subst, env)?)), + ASTNode::Struct(s) => Ok(ASTNode::Struct(rewrite_struct(s, subst, env)?)), + ASTNode::ExternFunction(e) => Ok(ASTNode::ExternFunction(rewrite_extern(e, subst, env)?)), + ASTNode::ProtoImpl(p) => Ok(ASTNode::ProtoImpl(rewrite_proto(p, subst, env)?)), + ASTNode::TypeAlias(TypeAliasNode { name, target }) => { + Ok(ASTNode::TypeAlias(TypeAliasNode { + name, + target: rewrite_wave_type(&target, subst, env)?, + })) + } + ASTNode::Enum(EnumNode { + name, + repr_type, + variants, + }) => Ok(ASTNode::Enum(EnumNode { + name, + repr_type: rewrite_wave_type(&repr_type, subst, env)?, + variants, + })), + ASTNode::Program(p) => Ok(ASTNode::Program(p)), + } +} + +fn rewrite_statement( + stmt: StatementNode, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + match stmt { + StatementNode::PrintFormat { format, args } => Ok(StatementNode::PrintFormat { + format, + args: rewrite_expr_list(args, subst, env)?, + }), + StatementNode::PrintlnFormat { format, args } => Ok(StatementNode::PrintlnFormat { + format, + args: rewrite_expr_list(args, subst, env)?, + }), + StatementNode::Input { format, args } => Ok(StatementNode::Input { + format, + args: rewrite_expr_list(args, subst, env)?, + }), + StatementNode::If { + condition, + body, + else_if_blocks, + else_block, + } => Ok(StatementNode::If { + condition: rewrite_expression(condition, subst, env)?, + body: rewrite_node_list(body, subst, env)?, + else_if_blocks: else_if_blocks + .map(|blocks| { + blocks + .into_iter() + .map(|(cond, body)| { + Ok(( + rewrite_expression(cond, subst, env)?, + rewrite_node_list(body, subst, env)?, + )) + }) + .collect::, String>>() + }) + .transpose()? + .map(Box::new), + else_block: else_block + .map(|body| rewrite_node_list(*body, subst, env).map(Box::new)) + .transpose()?, + }), + StatementNode::For { + initialization, + condition, + increment, + body, + } => Ok(StatementNode::For { + initialization: Box::new(rewrite_node(*initialization, subst, env)?), + condition: rewrite_expression(condition, subst, env)?, + increment: rewrite_expression(increment, subst, env)?, + body: rewrite_node_list(body, subst, env)?, + }), + StatementNode::While { condition, body } => Ok(StatementNode::While { + condition: rewrite_expression(condition, subst, env)?, + body: rewrite_node_list(body, subst, env)?, + }), + StatementNode::Match { value, arms } => Ok(StatementNode::Match { + value: rewrite_expression(value, subst, env)?, + arms: arms + .into_iter() + .map(|arm| { + Ok(MatchArm { + pattern: rewrite_pattern(arm.pattern), + body: rewrite_node_list(arm.body, subst, env)?, + }) + }) + .collect::, String>>()?, + }), + StatementNode::Assign { variable, value } => Ok(StatementNode::Assign { + variable, + value: rewrite_expression(value, subst, env)?, + }), + StatementNode::AsmBlock { + instructions, + inputs, + outputs, + clobbers, + } => Ok(StatementNode::AsmBlock { + instructions, + inputs: inputs + .into_iter() + .map(|(r, e)| Ok((r, rewrite_expression(e, subst, env)?))) + .collect::, String>>()?, + outputs: outputs + .into_iter() + .map(|(r, e)| Ok((r, rewrite_expression(e, subst, env)?))) + .collect::, String>>()?, + clobbers, + }), + StatementNode::Expression(e) => Ok(StatementNode::Expression(rewrite_expression( + e, subst, env, + )?)), + StatementNode::Return(v) => Ok(StatementNode::Return( + v.map(|e| rewrite_expression(e, subst, env)).transpose()?, + )), + StatementNode::Print(s) => Ok(StatementNode::Print(s)), + StatementNode::Println(s) => Ok(StatementNode::Println(s)), + StatementNode::Variable(v) => Ok(StatementNode::Variable(v)), + StatementNode::Import(s) => Ok(StatementNode::Import(s)), + StatementNode::Break => Ok(StatementNode::Break), + StatementNode::Continue => Ok(StatementNode::Continue), + } +} + +fn rewrite_pattern(pattern: MatchPattern) -> MatchPattern { + pattern +} + +fn rewrite_expr_list( + exprs: Vec, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result, String> { + exprs + .into_iter() + .map(|e| rewrite_expression(e, subst, env)) + .collect() +} + +fn rewrite_node_list( + nodes: Vec, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result, String> { + nodes + .into_iter() + .map(|n| rewrite_node(n, subst, env)) + .collect() +} + +fn rewrite_expression( + expr: Expression, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + match expr { + Expression::FunctionCall { + name, + type_args, + args, + } => { + let args = rewrite_expr_list(args, subst, env)?; + + if type_args.is_empty() { + if env.function_templates.contains_key(&name) { + return Err(format!( + "generic function '{}' requires explicit type arguments", + name + )); + } + return Ok(Expression::FunctionCall { + name, + type_args, + args, + }); + } + + let concrete_args: Vec = type_args + .iter() + .map(|t| rewrite_wave_type(t, subst, env)) + .collect::, _>>()?; + + if !env.function_templates.contains_key(&name) { + return Err(format!( + "type arguments provided for non-generic function '{}'", + name + )); + } + + let instantiated = ensure_function_instance(&name, &concrete_args, env)?; + Ok(Expression::FunctionCall { + name: instantiated, + type_args: Vec::new(), + args, + }) + } + Expression::MethodCall { object, name, args } => Ok(Expression::MethodCall { + object: Box::new(rewrite_expression(*object, subst, env)?), + name, + args: rewrite_expr_list(args, subst, env)?, + }), + Expression::StructLiteral { name, fields } => { + let rewritten_name = rewrite_struct_name_usage(&name, subst, env)?; + let mut rewritten_fields = Vec::with_capacity(fields.len()); + for (fname, value) in fields { + rewritten_fields.push((fname, rewrite_expression(value, subst, env)?)); + } + Ok(Expression::StructLiteral { + name: rewritten_name, + fields: rewritten_fields, + }) + } + Expression::Deref(inner) => Ok(Expression::Deref(Box::new(rewrite_expression( + *inner, subst, env, + )?))), + Expression::AddressOf(inner) => Ok(Expression::AddressOf(Box::new(rewrite_expression( + *inner, subst, env, + )?))), + Expression::BinaryExpression { + left, + operator, + right, + } => Ok(Expression::BinaryExpression { + left: Box::new(rewrite_expression(*left, subst, env)?), + operator, + right: Box::new(rewrite_expression(*right, subst, env)?), + }), + Expression::IndexAccess { target, index } => Ok(Expression::IndexAccess { + target: Box::new(rewrite_expression(*target, subst, env)?), + index: Box::new(rewrite_expression(*index, subst, env)?), + }), + Expression::ArrayLiteral(items) => Ok(Expression::ArrayLiteral(rewrite_expr_list( + items, subst, env, + )?)), + Expression::Grouped(inner) => Ok(Expression::Grouped(Box::new(rewrite_expression( + *inner, subst, env, + )?))), + Expression::AssignOperation { + target, + operator, + value, + } => Ok(Expression::AssignOperation { + target: Box::new(rewrite_expression(*target, subst, env)?), + operator, + value: Box::new(rewrite_expression(*value, subst, env)?), + }), + Expression::Assignment { target, value } => Ok(Expression::Assignment { + target: Box::new(rewrite_expression(*target, subst, env)?), + value: Box::new(rewrite_expression(*value, subst, env)?), + }), + Expression::AsmBlock { + instructions, + inputs, + outputs, + clobbers, + } => Ok(Expression::AsmBlock { + instructions, + inputs: inputs + .into_iter() + .map(|(r, e)| Ok((r, rewrite_expression(e, subst, env)?))) + .collect::, String>>()?, + outputs: outputs + .into_iter() + .map(|(r, e)| Ok((r, rewrite_expression(e, subst, env)?))) + .collect::, String>>()?, + clobbers, + }), + Expression::FieldAccess { object, field } => Ok(Expression::FieldAccess { + object: Box::new(rewrite_expression(*object, subst, env)?), + field, + }), + Expression::Unary { operator, expr } => Ok(Expression::Unary { + operator, + expr: Box::new(rewrite_expression(*expr, subst, env)?), + }), + Expression::Cast { expr, target_type } => Ok(Expression::Cast { + expr: Box::new(rewrite_expression(*expr, subst, env)?), + target_type: rewrite_wave_type(&target_type, subst, env)?, + }), + Expression::IncDec { kind, target } => Ok(Expression::IncDec { + kind, + target: Box::new(rewrite_expression(*target, subst, env)?), + }), + other => Ok(other), + } +} + +fn rewrite_wave_type( + ty: &WaveType, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + match ty { + WaveType::Pointer(inner) => Ok(WaveType::Pointer(Box::new(rewrite_wave_type( + inner, subst, env, + )?))), + WaveType::Array(inner, n) => Ok(WaveType::Array( + Box::new(rewrite_wave_type(inner, subst, env)?), + *n, + )), + WaveType::Struct(name) => rewrite_struct_type(name, subst, env), + _ => Ok(ty.clone()), + } +} + +fn rewrite_struct_type( + name: &str, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + if let Some(mapped) = subst.get(name) { + return Ok(mapped.clone()); + } + + if let Some((base, arg_strs)) = parse_type_application(name)? { + if !env.struct_templates.contains_key(&base) { + return Err(format!( + "generic type '{}' is not declared as a generic struct", + base + )); + } + + let mut concrete_args: Vec = Vec::with_capacity(arg_strs.len()); + for arg in arg_strs { + let parsed = parse_wave_type_from_str(&arg)?; + concrete_args.push(rewrite_wave_type(&parsed, subst, env)?); + } + + let instantiated = ensure_struct_instance(&base, &concrete_args, env)?; + return Ok(WaveType::Struct(instantiated)); + } + + if let Some(template) = env.struct_templates.get(name) { + if !template.generic_params.is_empty() { + return Err(format!( + "generic struct '{}' requires explicit type arguments", + name + )); + } + } + + Ok(WaveType::Struct(name.to_string())) +} + +fn rewrite_struct_name_usage( + name: &str, + subst: &HashMap, + env: &mut GenericEnv, +) -> Result { + match rewrite_struct_type(name, subst, env)? { + WaveType::Struct(n) => Ok(n), + other => Err(format!( + "invalid struct literal target '{}': expected struct type, got {:?}", + name, other + )), + } +} + +fn ensure_struct_instance( + base: &str, + args: &[WaveType], + env: &mut GenericEnv, +) -> Result { + let template = env + .struct_templates + .get(base) + .cloned() + .ok_or_else(|| format!("unknown generic struct template '{}'", base))?; + + if template.generic_params.len() != args.len() { + return Err(format!( + "generic struct '{}' expects {} type arguments, got {}", + base, + template.generic_params.len(), + args.len() + )); + } + + let inst_name = mangle_instance_name(base, args); + if env.struct_instances.contains_key(&inst_name) { + return Ok(inst_name); + } + if env.struct_in_progress.contains(&inst_name) { + return Ok(inst_name); + } + + let mut map: HashMap = HashMap::new(); + for (k, v) in template.generic_params.iter().zip(args.iter()) { + map.insert(k.clone(), v.clone()); + } + + env.struct_in_progress.insert(inst_name.clone()); + + let mut instantiated = template; + instantiated.name = inst_name.clone(); + instantiated.generic_params.clear(); + instantiated = rewrite_struct(instantiated, &map, env)?; + + env.struct_in_progress.remove(&inst_name); + env.struct_instances + .insert(inst_name.clone(), instantiated) + .map(|_| ()) + .unwrap_or(()); + + Ok(inst_name) +} + +fn ensure_function_instance( + base: &str, + args: &[WaveType], + env: &mut GenericEnv, +) -> Result { + let template = env + .function_templates + .get(base) + .cloned() + .ok_or_else(|| format!("unknown generic function template '{}'", base))?; + + if template.generic_params.len() != args.len() { + return Err(format!( + "generic function '{}' expects {} type arguments, got {}", + base, + template.generic_params.len(), + args.len() + )); + } + + let inst_name = mangle_instance_name(base, args); + if env.function_instances.contains_key(&inst_name) { + return Ok(inst_name); + } + if env.function_in_progress.contains(&inst_name) { + return Ok(inst_name); + } + + let mut map: HashMap = HashMap::new(); + for (k, v) in template.generic_params.iter().zip(args.iter()) { + map.insert(k.clone(), v.clone()); + } + + env.function_in_progress.insert(inst_name.clone()); + + let mut instantiated = template; + instantiated.name = inst_name.clone(); + instantiated.generic_params.clear(); + instantiated = rewrite_function(instantiated, &map, env)?; + + env.function_in_progress.remove(&inst_name); + env.function_instances + .insert(inst_name.clone(), instantiated) + .map(|_| ()) + .unwrap_or(()); + + Ok(inst_name) +} + +fn parse_type_application(name: &str) -> Result)>, String> { + let s = name.trim(); + let Some(lt) = s.find('<') else { + return Ok(None); + }; + if !s.ends_with('>') { + return Err(format!( + "malformed generic type '{}': missing closing '>'", + s + )); + } + + let base = s[..lt].trim(); + if base.is_empty() { + return Err(format!("malformed generic type '{}': missing base name", s)); + } + let inner = &s[lt + 1..s.len() - 1]; + let args = split_top_level_generic_args(inner).ok_or_else(|| { + format!( + "malformed generic type '{}': invalid generic argument list", + name + ) + })?; + Ok(Some((base.to_string(), args))) +} + +fn parse_wave_type_from_str(raw: &str) -> Result { + let tt = parse_type(raw).ok_or_else(|| format!("invalid type syntax '{}'", raw))?; + token_type_to_wave_type(&tt) + .ok_or_else(|| format!("unsupported type '{}' in generic argument", raw)) +} + +fn mangle_instance_name(base: &str, args: &[WaveType]) -> String { + let mut out = String::with_capacity(base.len() + 16); + out.push_str(base); + out.push_str("$g"); + for arg in args { + out.push('$'); + out.push_str(&mangle_type(arg)); + } + out +} + +fn mangle_type(ty: &WaveType) -> String { + match ty { + WaveType::Int(n) => format!("i{}", n), + WaveType::Uint(n) => format!("u{}", n), + WaveType::Float(n) => format!("f{}", n), + WaveType::Bool => "bool".to_string(), + WaveType::Char => "char".to_string(), + WaveType::Byte => "byte".to_string(), + WaveType::String => "str".to_string(), + WaveType::Void => "void".to_string(), + WaveType::Pointer(inner) => format!("p_{}", mangle_type(inner)), + WaveType::Array(inner, n) => format!("a{}_{}", n, mangle_type(inner)), + WaveType::Struct(name) => sanitize_ident(name), + } +} + +fn sanitize_ident(raw: &str) -> String { + let mut out = String::with_capacity(raw.len()); + for c in raw.chars() { + if c.is_ascii_alphanumeric() || c == '_' { + out.push(c); + } else { + out.push('_'); + } + } + out +} diff --git a/front/parser/src/lib.rs b/front/parser/src/lib.rs index 18ec33e3..e7af5745 100644 --- a/front/parser/src/lib.rs +++ b/front/parser/src/lib.rs @@ -20,6 +20,7 @@ macro_rules! println { pub mod ast; pub mod expr; pub mod format; +pub mod generics; pub mod import; pub mod parser; pub mod stdlib; diff --git a/front/parser/src/parser/decl.rs b/front/parser/src/parser/decl.rs index a822b056..d124e4f9 100644 --- a/front/parser/src/parser/decl.rs +++ b/front/parser/src/parser/decl.rs @@ -86,6 +86,7 @@ pub fn parse_variable_decl( Mutability::Let }; + skip_ws(tokens); if !is_const { if let Some(Token { token_type: TokenType::Mut, @@ -97,6 +98,7 @@ pub fn parse_variable_decl( } } + skip_ws(tokens); let name = match tokens.next() { Some(Token { token_type: TokenType::Identifier(name), @@ -111,11 +113,13 @@ pub fn parse_variable_decl( } }; + skip_ws(tokens); if !matches!(tokens.next().map(|t| &t.token_type), Some(TokenType::Colon)) { println!("Expected ':' after identifier"); return None; } + skip_ws(tokens); let type_token = match tokens.next() { Some(token) => token.clone(), _ => { @@ -167,6 +171,7 @@ pub fn parse_variable_decl( } }; + skip_ws(tokens); let initial_value = if let Some(Token { token_type: TokenType::Equal, .. @@ -227,6 +232,7 @@ pub fn parse_let(tokens: &mut Peekable>) -> Option { pub fn parse_var(tokens: &mut Peekable>) -> Option { let mutability = Mutability::Var; + skip_ws(tokens); let name = match tokens.next() { Some(Token { token_type: TokenType::Identifier(name), @@ -238,11 +244,13 @@ pub fn parse_var(tokens: &mut Peekable>) -> Option { } }; + skip_ws(tokens); if !matches!(tokens.next().map(|t| &t.token_type), Some(TokenType::Colon)) { println!("Expected ':' after identifier"); return None; } + skip_ws(tokens); let type_token = match tokens.next() { Some(token) => token.clone(), _ => { @@ -294,6 +302,7 @@ pub fn parse_var(tokens: &mut Peekable>) -> Option { } }; + skip_ws(tokens); let initial_value = if let Some(Token { token_type: TokenType::Equal, .. diff --git a/front/parser/src/parser/expr.rs b/front/parser/src/parser/expr.rs index e74c81c5..7543a7a3 100644 --- a/front/parser/src/parser/expr.rs +++ b/front/parser/src/parser/expr.rs @@ -54,7 +54,11 @@ pub fn parse_function_call( } } - Some(Expression::FunctionCall { name, args }) + Some(Expression::FunctionCall { + name, + type_args: Vec::new(), + args, + }) } pub fn parse_parentheses(tokens: &mut Peekable>) -> Vec { diff --git a/front/parser/src/parser/functions.rs b/front/parser/src/parser/functions.rs index f88deac7..8fb0c778 100644 --- a/front/parser/src/parser/functions.rs +++ b/front/parser/src/parser/functions.rs @@ -23,12 +23,87 @@ use std::collections::HashSet; use std::iter::Peekable; use std::slice::Iter; +fn skip_ws(tokens: &mut Peekable>) { + while matches!( + tokens.peek().map(|t| &t.token_type), + Some(TokenType::Whitespace | TokenType::Newline) + ) { + tokens.next(); + } +} + +pub fn parse_generic_param_names(tokens: &mut Peekable>) -> Option> { + skip_ws(tokens); + if !matches!( + tokens.peek().map(|t| &t.token_type), + Some(TokenType::Lchevr) + ) { + return Some(Vec::new()); + } + + tokens.next(); // consume '<' + let mut params: Vec = Vec::new(); + let mut seen: HashSet = HashSet::new(); + + loop { + skip_ws(tokens); + + if matches!( + tokens.peek().map(|t| &t.token_type), + Some(TokenType::Rchevr) + ) { + tokens.next(); // consume '>' + break; + } + + let ident = match tokens.next() { + Some(Token { + token_type: TokenType::Identifier(name), + .. + }) => name.clone(), + _ => { + println!("Error: Expected generic parameter name inside '<...>'"); + return None; + } + }; + + if !seen.insert(ident.clone()) { + println!("Error: Duplicate generic parameter '{}'", ident); + return None; + } + params.push(ident); + + skip_ws(tokens); + match tokens.peek().map(|t| &t.token_type) { + Some(TokenType::Comma) => { + tokens.next(); + } + Some(TokenType::Rchevr) => { + tokens.next(); // consume '>' + break; + } + _ => { + println!("Error: Expected ',' or '>' in generic parameter list"); + return None; + } + } + } + + Some(params) +} + pub fn parse_parameters(tokens: &mut Peekable>) -> Vec { let mut params = vec![]; - while tokens - .peek() - .map_or(false, |t| t.token_type != TokenType::Rparen) - { + loop { + skip_ws(tokens); + + if tokens + .peek() + .map_or(false, |t| t.token_type == TokenType::Rparen) + { + break; + } + let name = if let Some(Token { token_type: TokenType::Identifier(n), .. @@ -40,6 +115,7 @@ pub fn parse_parameters(tokens: &mut Peekable>) -> Vec>) -> Vec { tokens.next(); // consume ',' @@ -123,6 +200,8 @@ pub fn parse_parameters(tokens: &mut Peekable>) -> Vec>) -> Option { tokens.next(); + skip_ws(tokens); + let name = match tokens.next() { Some(Token { token_type: TokenType::Identifier(name), @@ -131,6 +210,9 @@ pub fn parse_function(tokens: &mut Peekable>) -> Option { _ => return None, }; + let generic_params = parse_generic_param_names(tokens)?; + + skip_ws(tokens); if tokens.peek()?.token_type != TokenType::Lparen { return None; } @@ -149,6 +231,7 @@ pub fn parse_function(tokens: &mut Peekable>) -> Option { } } + skip_ws(tokens); let return_type = if let Some(Token { token_type: TokenType::Arrow, .. @@ -160,9 +243,11 @@ pub fn parse_function(tokens: &mut Peekable>) -> Option { None }; + skip_ws(tokens); let body = extract_body(tokens)?; Some(ASTNode::Function(FunctionNode { name, + generic_params, parameters, body, return_type, diff --git a/front/parser/src/parser/items.rs b/front/parser/src/parser/items.rs index 7a29f030..b25d7d72 100644 --- a/front/parser/src/parser/items.rs +++ b/front/parser/src/parser/items.rs @@ -10,7 +10,7 @@ // SPDX-License-Identifier: MPL-2.0 use crate::ast::{ASTNode, ProtoImplNode, StatementNode, StructNode, WaveType}; -use crate::parser::functions::parse_function; +use crate::parser::functions::{parse_function, parse_generic_param_names}; use crate::types::parse_type_from_stream; use lexer::token::TokenType; use lexer::Token; @@ -151,6 +151,8 @@ pub fn parse_struct(tokens: &mut Peekable>) -> Option { } }; + let generic_params = parse_generic_param_names(tokens)?; + if tokens .peek() .map_or(true, |t| t.token_type != TokenType::Lbrace) @@ -296,6 +298,7 @@ pub fn parse_struct(tokens: &mut Peekable>) -> Option { Some(ASTNode::Struct(StructNode { name, + generic_params, fields, methods, })) diff --git a/front/parser/src/parser/types.rs b/front/parser/src/parser/types.rs index a694df04..c31849d7 100644 --- a/front/parser/src/parser/types.rs +++ b/front/parser/src/parser/types.rs @@ -15,6 +15,45 @@ use lexer::token::*; use lexer::Token; use std::iter::Peekable; +pub fn split_top_level_generic_args(inner: &str) -> Option> { + let mut parts: Vec = Vec::new(); + let mut depth: i32 = 0; + let mut start: usize = 0; + + for (i, c) in inner.char_indices() { + match c { + '<' => depth += 1, + '>' => { + depth -= 1; + if depth < 0 { + return None; + } + } + ',' if depth == 0 => { + let part = inner[start..i].trim(); + if part.is_empty() { + return None; + } + parts.push(part.to_string()); + start = i + 1; + } + _ => {} + } + } + + if depth != 0 { + return None; + } + + let tail = inner[start..].trim(); + if tail.is_empty() { + return None; + } + parts.push(tail.to_string()); + + Some(parts) +} + pub fn token_type_to_wave_type(token_type: &TokenType) -> Option { match token_type { TokenType::TypeVoid => Some(WaveType::Void), @@ -90,28 +129,19 @@ pub fn parse_type(type_str: &str) -> Option { return None; } - let base = &type_str[..lt_index]; + let base = type_str[..lt_index].trim(); + if base.is_empty() { + return None; + } let inner = &type_str[lt_index + 1..type_str.len() - 1]; if base == "array" { - let mut depth = 0; - let mut split_pos = None; - - for (i, c) in inner.char_indices() { - match c { - '<' => depth += 1, - '>' => depth -= 1, - ',' if depth == 0 => { - split_pos = Some(i); - break; - } - _ => {} - } + let args = split_top_level_generic_args(inner)?; + if args.len() != 2 { + return None; } - - let split_pos = split_pos?; - let elem_type_str = inner[..split_pos].trim(); - let size_str = inner[split_pos + 1..].trim(); + let elem_type_str = args[0].trim(); + let size_str = args[1].trim(); let elem_type = parse_type(elem_type_str)?; let size = size_str.parse::().ok()?; @@ -124,7 +154,11 @@ pub fn parse_type(type_str: &str) -> Option { return Some(TokenType::TypePointer(Box::new(inner_type))); } - return None; + let args = split_top_level_generic_args(inner)?; + for arg in &args { + let _ = parse_type(arg)?; + } + return Some(TokenType::TypeCustom(type_str.to_string())); } if type_str.starts_with('i') { diff --git a/front/parser/tests/parse_var_and_generics.rs b/front/parser/tests/parse_var_and_generics.rs new file mode 100644 index 00000000..5dde675a --- /dev/null +++ b/front/parser/tests/parse_var_and_generics.rs @@ -0,0 +1,50 @@ +use lexer::Lexer; +use parser::parse_syntax_only; + +fn parse_ok(src: &str) { + let mut lexer = Lexer::new(src); + let tokens = lexer.tokenize().expect("lex should succeed"); + let parsed = parse_syntax_only(&tokens); + if let Err(err) = parsed { + let mut dump = String::new(); + for (idx, t) in tokens.iter().enumerate() { + dump.push_str(&format!( + "{:03}: line={} {:?} lexeme=`{}`\n", + idx, t.line, t.token_type, t.lexeme + )); + } + panic!( + "parse failed: {:?}\nsource:\n{}\ntokens:\n{}", + err, src, dump + ); + } +} + +#[test] +fn parses_var_in_function_body() { + parse_ok( + r#" +fun main() { + var x: i32; + return; +} +"#, + ); +} + +#[test] +fn parses_multigeneric_function_and_types() { + parse_ok( + r#" +struct Pair { + first: A; + second: B; +} + +fun make_pair(a: A, b: B) -> Pair { + var pair_value: Pair; + return pair_value; +} +"#, + ); +} diff --git a/llvm/src/expression/rvalue/calls.rs b/llvm/src/expression/rvalue/calls.rs index c41cc833..411eff04 100644 --- a/llvm/src/expression/rvalue/calls.rs +++ b/llvm/src/expression/rvalue/calls.rs @@ -357,9 +357,17 @@ pub(crate) fn gen_method_call<'ctx, 'a>( pub(crate) fn gen_function_call<'ctx, 'a>( env: &mut ExprGenEnv<'ctx, 'a>, name: &str, + type_args: &[WaveType], args: &[Expression], expected_type: Option>, ) -> BasicValueEnum<'ctx> { + if !type_args.is_empty() { + panic!( + "generic call '{}<...>(...)' reached codegen without monomorphization", + name + ); + } + if let Some(info) = env.extern_c_info.get(name) { let function = env.module.get_function(&info.llvm_name).unwrap_or_else(|| { panic!( diff --git a/llvm/src/expression/rvalue/dispatch.rs b/llvm/src/expression/rvalue/dispatch.rs index 0101df42..7c852dc0 100644 --- a/llvm/src/expression/rvalue/dispatch.rs +++ b/llvm/src/expression/rvalue/dispatch.rs @@ -30,9 +30,11 @@ pub(crate) fn gen_expr<'ctx, 'a>( Expression::MethodCall { object, name, args } => { calls::gen_method_call(env, object, name, args) } - Expression::FunctionCall { name, args } => { - calls::gen_function_call(env, name, args, expected_type) - } + Expression::FunctionCall { + name, + type_args, + args, + } => calls::gen_function_call(env, name, type_args, args, expected_type), Expression::Cast { expr, target_type } => cast::gen(env, expr, target_type), Expression::AssignOperation { diff --git a/src/runner.rs b/src/runner.rs index 3d056447..3b6d4bb4 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -12,6 +12,7 @@ use crate::{DebugFlags, DepFlags, LinkFlags, LlvmFlags}; use ::error::*; use ::parser::ast::*; +use ::parser::generics::monomorphize_generics; use ::parser::import::*; use ::parser::verification::validate_program; use ::parser::*; @@ -1000,6 +1001,26 @@ pub(crate) unsafe fn run_wave_file( process::exit(1); } }; + let ast = match monomorphize_generics(ast) { + Ok(a) => a, + Err(msg) => { + WaveError::new( + WaveErrorKind::InvalidStatement(msg.clone()), + format!("generic monomorphization failed: {}", msg), + file_path.display().to_string(), + 1, + 1, + ) + .with_code("E3001") + .with_source_code(code.to_string()) + .with_context("generic instantiation") + .with_help( + "check generic type arguments, generic function calls, and generic struct usages", + ) + .display(); + process::exit(1); + } + }; validate_wave_ast_or_exit(file_path, &code, &ast); @@ -1117,6 +1138,23 @@ pub(crate) unsafe fn object_build_wave_file( e.display(); process::exit(1); }); + let ast = monomorphize_generics(ast).unwrap_or_else(|msg| { + WaveError::new( + WaveErrorKind::InvalidStatement(msg.clone()), + format!("generic monomorphization failed: {}", msg), + file_path.display().to_string(), + 1, + 1, + ) + .with_code("E3001") + .with_source_code(code.to_string()) + .with_context("generic instantiation") + .with_help( + "check generic type arguments, generic function calls, and generic struct usages", + ) + .display(); + process::exit(1); + }); validate_wave_ast_or_exit(file_path, &code, &ast); diff --git a/std/buffer/alloc.wave b/std/buffer/alloc.wave index 92a835bd..d0118615 100644 --- a/std/buffer/alloc.wave +++ b/std/buffer/alloc.wave @@ -88,3 +88,99 @@ fun buffer_reserve(buf: ptr, required_cap: i64) -> i64 { return 0; } + +fun tbuffer_new(elem_size: i64, initial_cap: i64) -> TypedBuffer { + if (elem_size <= 0) { + var empty: TypedBuffer; + empty.data = null; + empty.len = 0; + empty.cap_bytes = 0; + empty.elem_size = 0; + return empty; + } + + var cap_elems: i64 = initial_cap; + if (cap_elems <= 0) { + cap_elems = 16; + } + + var cap_bytes: i64 = cap_elems * elem_size; + var data: ptr = mem_alloc(cap_bytes); + + if (data == null) { + var failed: TypedBuffer; + failed.data = null; + failed.len = 0; + failed.cap_bytes = 0; + failed.elem_size = elem_size; + return failed; + } + + var created: TypedBuffer; + created.data = data; + created.len = 0; + created.cap_bytes = cap_bytes; + created.elem_size = elem_size; + return created; +} + +fun tbuffer_free(buf: ptr>) -> i64 { + var ret: i64 = 0; + + if (deref buf.cap_bytes > 0) { + ret = mem_free(deref buf.data, deref buf.cap_bytes); + } + + deref buf.data = null; + deref buf.len = 0; + deref buf.cap_bytes = 0; + + return ret; +} + +fun tbuffer_clear(buf: ptr>) { + deref buf.len = 0; +} + +fun tbuffer_reserve(buf: ptr>, required_len: i64) -> i64 { + if (deref buf.elem_size <= 0) { + return -1; + } + + if (required_len <= 0) { + return 0; + } + + var required_bytes: i64 = required_len * deref buf.elem_size; + if (required_bytes <= deref buf.cap_bytes) { + return 0; + } + + var new_cap_bytes: i64 = deref buf.cap_bytes; + if (new_cap_bytes <= 0) { + new_cap_bytes = deref buf.elem_size * 16; + } + + while (new_cap_bytes < required_bytes) { + new_cap_bytes = new_cap_bytes * 2; + } + + var new_data: ptr = mem_alloc(new_cap_bytes); + if (new_data == null) { + return -1; + } + + var used_bytes: i64 = deref buf.len * deref buf.elem_size; + if (used_bytes > 0) { + mem_copy(new_data, deref buf.data, used_bytes); + } + + if (deref buf.cap_bytes > 0) { + mem_free(deref buf.data, deref buf.cap_bytes); + } + + deref buf.data = new_data; + deref buf.cap_bytes = new_cap_bytes; + + return 0; +} diff --git a/std/buffer/read.wave b/std/buffer/read.wave index 479a1f22..fd069c85 100644 --- a/std/buffer/read.wave +++ b/std/buffer/read.wave @@ -18,3 +18,23 @@ fun buffer_at(buf: Buffer, index: i64) -> u8 { return buf.data[index]; } + +fun tbuffer_ptr(buf: TypedBuffer) -> ptr { + return buf.data as ptr; +} + +fun tbuffer_len(buf: TypedBuffer) -> i64 { + return buf.len; +} + +fun tbuffer_at(buf: TypedBuffer, index: i64, out_value: ptr) -> bool { + if (index < 0 || index >= buf.len) { + return false; + } + + var offset_bytes: i64 = index * buf.elem_size; + var slot: ptr = (buf.data + offset_bytes) as ptr; + deref out_value = deref slot; + + return true; +} diff --git a/std/buffer/types.wave b/std/buffer/types.wave index b27cf0fe..2d8094cc 100644 --- a/std/buffer/types.wave +++ b/std/buffer/types.wave @@ -14,3 +14,10 @@ struct Buffer { len: i64; cap: i64; } + +struct TypedBuffer { + data: ptr; + len: i64; + cap_bytes: i64; + elem_size: i64; +} diff --git a/std/buffer/write.wave b/std/buffer/write.wave index f14adf77..39dc1d79 100644 --- a/std/buffer/write.wave +++ b/std/buffer/write.wave @@ -77,3 +77,32 @@ fun buffer_set(buf: ptr, index: i64, value: u8) -> bool { deref p[index] = value; return true; } + +fun tbuffer_push(buf: ptr>, value: T) -> i64 { + var needed_len: i64 = deref buf.len + 1; + var ret: i64 = tbuffer_reserve(buf, needed_len); + + if (ret < 0) { + return ret; + } + + var offset_bytes: i64 = deref buf.len * deref buf.elem_size; + var slot: ptr = (deref buf.data + offset_bytes) as ptr; + + deref slot = value; + deref buf.len = needed_len; + + return 0; +} + +fun tbuffer_set(buf: ptr>, index: i64, value: T) -> bool { + if (index < 0 || index >= deref buf.len) { + return false; + } + + var offset_bytes: i64 = index * deref buf.elem_size; + var slot: ptr = (deref buf.data + offset_bytes) as ptr; + deref slot = value; + + return true; +} diff --git a/std/env/environ.wave b/std/env/environ.wave index 24d2fa28..ae70e494 100644 --- a/std/env/environ.wave +++ b/std/env/environ.wave @@ -12,6 +12,33 @@ import("std::sys::env"); import("std::env::parse"); +struct EnvResult { + ok: bool; + value: T; +} + +fun env_result_ok(value: T) -> EnvResult { + var result: EnvResult; + result.ok = true; + result.value = value; + return result; +} + +fun env_result_err(fallback: T) -> EnvResult { + var result: EnvResult; + result.ok = false; + result.value = fallback; + return result; +} + +fun env_unwrap_or(result: EnvResult, default_value: T) -> T { + if (result.ok) { + return result.value; + } + + return default_value; +} + fun env_get(name: str, dst: ptr, dst_cap: i64) -> i64 { var key_len: i64 = _env_key_len(name); @@ -77,38 +104,46 @@ fun env_exists(name: str) -> bool { return false; } -fun env_get_i32_default(name: str, default_value: i32) -> i32 { +fun env_get_i64(name: str) -> EnvResult { var raw: array; var n: i64 = env_get(name, &raw[0], 64); if (n <= 0) { - return default_value; + return env_result_err(0); } - var i: i64 = 0; - var sign: i32 = 1; - - if (raw[0] == 45) { - sign = -1; - i = 1; + var parsed: i64 = 0; + if (!_env_parse_i64(&raw[0], &parsed)) { + return env_result_err(0); } - var value: i32 = 0; + return env_result_ok(parsed); +} - while (raw[i] != 0) { - var c: u8 = raw[i]; - if (c < 48 || c > 57) { - return default_value; - } +fun env_get_i32(name: str) -> EnvResult { + var raw: array; + var n: i64 = env_get(name, &raw[0], 64); + + if (n <= 0) { + return env_result_err(0); + } - var digit: i32 = c - 48; - value = (value * 10) + digit; - i += 1; + var parsed_i64: i64 = 0; + if (!_env_parse_i64(&raw[0], &parsed_i64)) { + return env_result_err(0); } - if (sign < 0) { - return -value; + if (parsed_i64 < -2147483648 || parsed_i64 > 2147483647) { + return env_result_err(0); } - return value; + return env_result_ok(parsed_i64 as i32); +} + +fun env_get_i32_default(name: str, default_value: i32) -> i32 { + return env_unwrap_or(env_get_i32(name), default_value); +} + +fun env_get_i64_default(name: str, default_value: i64) -> i64 { + return env_unwrap_or(env_get_i64(name), default_value); } diff --git a/std/env/parse.wave b/std/env/parse.wave index 5fd9b6a3..88238d06 100644 --- a/std/env/parse.wave +++ b/std/env/parse.wave @@ -61,3 +61,38 @@ fun _env_copy_value( deref dst[value_len] = 0; return value_len; } + +fun _env_parse_i64(raw: ptr, out_value: ptr) -> bool { + var i: i64 = 0; + var sign: i64 = 1; + + if (raw[0] == 45) { + sign = -1; + i = 1; + } else if (raw[0] == 43) { + i = 1; + } + + if (raw[i] == 0) { + return false; + } + + var value: i64 = 0; + while (raw[i] != 0) { + var c: u8 = raw[i]; + if (c < 48 || c > 57) { + return false; + } + + var digit: i64 = c - 48; + value = (value * 10) + digit; + i += 1; + } + + if (sign < 0) { + value = -value; + } + + deref out_value = value; + return true; +} diff --git a/std/math/float.wave b/std/math/float.wave index a2e3e764..6e3b2141 100644 --- a/std/math/float.wave +++ b/std/math/float.wave @@ -9,38 +9,20 @@ // // SPDX-License-Identifier: MPL-2.0 -fun abs_f32(x: f32) -> f32 { - if (x < 0.0) { - return -x; - } +import("std::math::int"); - return x; +fun abs_f32(x: f32) -> f32 { + return num_abs(x, 0.0); } fun min_f32(a: f32, b: f32) -> f32 { - if (a < b) { - return a; - } - - return b; + return num_min(a, b); } fun max_f32(a: f32, b: f32) -> f32 { - if (a > b) { - return a; - } - - return b; + return num_max(a, b); } fun clamp_f32(x: f32, lo: f32, hi: f32) -> f32 { - if (x < lo) { - return lo; - } - - if (x > hi) { - return hi; - } - - return x; + return num_clamp(x, lo, hi); } diff --git a/std/math/int.wave b/std/math/int.wave index c7093af2..efd1c55b 100644 --- a/std/math/int.wave +++ b/std/math/int.wave @@ -9,15 +9,15 @@ // // SPDX-License-Identifier: MPL-2.0 -fun abs(x: i32) -> i32 { - if (x < 0) { +fun num_abs(x: T, zero: T) -> T { + if (x < zero) { return -x; } return x; } -fun min(a: i32, b: i32) -> i32 { +fun num_min(a: T, b: T) -> T { if (a < b) { return a; } @@ -25,7 +25,7 @@ fun min(a: i32, b: i32) -> i32 { return b; } -fun max(a: i32, b: i32) -> i32 { +fun num_max(a: T, b: T) -> T { if (a > b) { return a; } @@ -33,7 +33,7 @@ fun max(a: i32, b: i32) -> i32 { return b; } -fun clamp(x: i32, lo: i32, hi: i32) -> i32 { +fun num_clamp(x: T, lo: T, hi: T) -> T { if (x < lo) { return lo; } @@ -45,6 +45,29 @@ fun clamp(x: i32, lo: i32, hi: i32) -> i32 { return x; } +fun ptr_swap(a: ptr, b: ptr) { + let t: T = deref a; + + deref a = deref b; + deref b = t; +} + +fun abs(x: i32) -> i32 { + return num_abs(x, 0); +} + +fun min(a: i32, b: i32) -> i32 { + return num_min(a, b); +} + +fun max(a: i32, b: i32) -> i32 { + return num_max(a, b); +} + +fun clamp(x: i32, lo: i32, hi: i32) -> i32 { + return num_clamp(x, lo, hi); +} + fun sign(x: i32) -> i32 { if (x < 0) { return -1; @@ -98,8 +121,5 @@ fun div_floor_pos(a: i32, b: i32) -> i32 { } fun swap_i32(a: ptr, b: ptr) { - let t: i32 = deref a; - - deref a = deref b; - deref b = t; + ptr_swap(a, b); } diff --git a/std/math/trig.wave b/std/math/trig.wave index acd9608d..fc3dcdc9 100644 --- a/std/math/trig.wave +++ b/std/math/trig.wave @@ -9,15 +9,13 @@ // // SPDX-License-Identifier: MPL-2.0 +import("std::math::int"); + const MATH_PI_F64: f64 = 3.141592653589793; const MATH_TWO_PI_F64: f64 = 6.283185307179586; fun abs_f64(x: f64) -> f64 { - if (x < 0.0) { - return -x; - } - - return x; + return num_abs(x, 0.0); } fun wrap_angle_pi_f64(x: f64) -> f64 { diff --git a/std/mem/alloc.wave b/std/mem/alloc.wave index b412a56f..253e6d7b 100644 --- a/std/mem/alloc.wave +++ b/std/mem/alloc.wave @@ -10,6 +10,7 @@ // SPDX-License-Identifier: MPL-2.0 import("std::sys::memory"); +import("std::mem::ops"); fun mem_alloc(size: i64) -> ptr { return sys_alloc(size); @@ -32,3 +33,81 @@ fun mem_alloc_zeroed(size: i64) -> ptr { fun mem_free(p: ptr, size: i64) -> i64 { return sys_free(p, size); } + +fun mem_realloc(old_ptr: ptr, old_size: i64, new_size: i64) -> ptr { + if (new_size <= 0) { + if (old_ptr != null && old_size > 0) { + mem_free(old_ptr, old_size); + } + return null; + } + + if (old_ptr == null || old_size <= 0) { + return mem_alloc(new_size); + } + + var new_ptr: ptr = mem_alloc(new_size); + if (new_ptr == null) { + return null; + } + + var copy_size: i64 = old_size; + if (new_size < copy_size) { + copy_size = new_size; + } + mem_copy(new_ptr, old_ptr, copy_size); + + mem_free(old_ptr, old_size); + return new_ptr; +} + +fun mem_alloc_items(count: i64, elem_size: i64) -> ptr { + if (count <= 0 || elem_size <= 0) { + return null; + } + + return mem_alloc(count * elem_size) as ptr; +} + +fun mem_alloc_items_zeroed(count: i64, elem_size: i64) -> ptr { + if (count <= 0 || elem_size <= 0) { + return null; + } + + return mem_alloc_zeroed(count * elem_size) as ptr; +} + +fun mem_realloc_items( + old_ptr: ptr, + old_count: i64, + new_count: i64, + elem_size: i64 +) -> ptr { + if (elem_size <= 0) { + return null; + } + + var old_size: i64 = 0; + if (old_count > 0) { + old_size = old_count * elem_size; + } + + var new_size: i64 = 0; + if (new_count > 0) { + new_size = new_count * elem_size; + } + + return mem_realloc(old_ptr as ptr, old_size, new_size) as ptr; +} + +fun mem_free_items(p: ptr, count: i64, elem_size: i64) -> i64 { + if (p == null) { + return 0; + } + + if (count <= 0 || elem_size <= 0) { + return 0; + } + + return mem_free(p as ptr, count * elem_size); +} diff --git a/std/mem/cstr.wave b/std/mem/cstr.wave index be6bf70a..71329679 100644 --- a/std/mem/cstr.wave +++ b/std/mem/cstr.wave @@ -30,3 +30,57 @@ fun mem_copy_cstr(dst: ptr, s: str) -> i64 { deref dst[i] = 0; return i; } + +fun mem_copy_cstr_n(dst: ptr, dst_cap: i64, s: str) -> i64 { + if (dst_cap <= 0) { + return -1; + } + + var i: i64 = 0; + var max_copy: i64 = dst_cap - 1; + + while (s[i] != 0 && i < max_copy) { + deref dst[i] = s[i]; + i += 1; + } + + deref dst[i] = 0; + return i; +} + +fun mem_eq_cstr(a: str, b: str) -> bool { + var i: i64 = 0; + while (true) { + if (a[i] != b[i]) { + return false; + } + if (a[i] == 0) { + return true; + } + i += 1; + } + + return true; +} + +fun mem_starts_with_cstr(s: str, prefix: str) -> bool { + var i: i64 = 0; + while (prefix[i] != 0) { + if (s[i] != prefix[i]) { + return false; + } + i += 1; + } + return true; +} + +fun mem_find_cstr_char(s: str, c: u8) -> i64 { + var i: i64 = 0; + while (s[i] != 0) { + if (s[i] == c) { + return i; + } + i += 1; + } + return -1; +} diff --git a/std/mem/ops.wave b/std/mem/ops.wave index 3a102514..d9a00616 100644 --- a/std/mem/ops.wave +++ b/std/mem/ops.wave @@ -29,6 +29,31 @@ fun mem_copy(dst: ptr, src: ptr, size: i64) { } } +fun mem_move(dst: ptr, src: ptr, size: i64) { + if (size <= 0) { + return; + } + + if (dst == src) { + return; + } + + if ((dst as i64) < (src as i64)) { + var i: i64 = 0; + while (i < size) { + deref dst[i] = src[i]; + i += 1; + } + return; + } + + var i: i64 = size; + while (i > 0) { + i -= 1; + deref dst[i] = src[i]; + } +} + fun mem_cmp(a: ptr, b: ptr, size: i64) -> i32 { var i: i64 = 0; @@ -49,3 +74,64 @@ fun mem_cmp(a: ptr, b: ptr, size: i64) -> i32 { return 0; } + +fun mem_eq(a: ptr, b: ptr, size: i64) -> bool { + if (mem_cmp(a, b, size) == 0) { + return true; + } + + return false; +} + +fun mem_find_byte(src: ptr, size: i64, value: u8) -> i64 { + var i: i64 = 0; + while (i < size) { + if (src[i] == value) { + return i; + } + i += 1; + } + return -1; +} + +fun mem_swap(a: ptr, b: ptr) { + var tmp: T = deref a; + deref a = deref b; + deref b = tmp; +} + +fun mem_copy_items(dst: ptr, src: ptr, count: i64, elem_size: i64) { + if (count <= 0 || elem_size <= 0) { + return; + } + + mem_copy(dst as ptr, src as ptr, count * elem_size); +} + +fun mem_set_items(dst: ptr, value: T, count: i64, elem_size: i64) { + if (count <= 0 || elem_size <= 0) { + return; + } + + var i: i64 = 0; + while (i < count) { + deref dst[i] = value; + i += 1; + } +} + +fun mem_move_items(dst: ptr, src: ptr, count: i64, elem_size: i64) { + if (count <= 0 || elem_size <= 0) { + return; + } + + mem_move(dst as ptr, src as ptr, count * elem_size); +} + +fun mem_zero_items(dst: ptr, count: i64, elem_size: i64) { + if (count <= 0 || elem_size <= 0) { + return; + } + + mem_zero(dst as ptr, count * elem_size); +} diff --git a/std/net/address.wave b/std/net/address.wave new file mode 100644 index 00000000..a9680b85 --- /dev/null +++ b/std/net/address.wave @@ -0,0 +1,74 @@ +// This file is part of the Wave language project. +// Copyright (c) 2024-2026 Wave Foundation +// Copyright (c) 2024-2026 LunaStev and contributors +// +// This Source Code Form is subject to the terms of the +// Mozilla Public License, v. 2.0. +// If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +import("std::sys::socket"); + +struct NetAddrV4 { + ip: i32; // network byte order + port: i16; // network byte order +} + +struct NetSockAddrIn { + family: i16; + port: i16; + addr: i32; + zero: array; +} + +fun net_htons(x: i16) -> i16 { + return ((x & 255) << 8) | ((x >> 8) & 255); +} + +fun net_ntohs(x: i16) -> i16 { + return net_htons(x); +} + +fun net_htonl(x: i32) -> i32 { + return ((x & 0x000000FF) << 24) + | ((x & 0x0000FF00) << 8) + | ((x & 0x00FF0000) >> 8) + | ((x & 0xFF000000) >> 24); +} + +fun net_ntohl(x: i32) -> i32 { + return net_htonl(x); +} + +fun net_addr_v4(ip_host_order: i32, port_host_order: i16) -> NetAddrV4 { + var addr_value: NetAddrV4; + addr_value.ip = net_htonl(ip_host_order); + addr_value.port = net_htons(port_host_order); + return addr_value; +} + +fun net_addr_any_v4(port_host_order: i16) -> NetAddrV4 { + return net_addr_v4(0, port_host_order); +} + +fun net_addr_loopback_v4(port_host_order: i16) -> NetAddrV4 { + return net_addr_v4(0x7F000001, port_host_order); +} + +fun net_to_sockaddr_v4(addr: NetAddrV4) -> NetSockAddrIn { + return NetSockAddrIn { + family: AF_INET as i16, + port: addr.port, + addr: addr.ip, + zero: [0,0,0,0,0,0,0,0] + }; +} + +fun net_from_sockaddr_v4(sa: NetSockAddrIn) -> NetAddrV4 { + return NetAddrV4 { + ip: sa.addr, + port: sa.port + }; +} diff --git a/std/net/socket_base.wave b/std/net/socket_base.wave new file mode 100644 index 00000000..526e49fb --- /dev/null +++ b/std/net/socket_base.wave @@ -0,0 +1,133 @@ +// This file is part of the Wave language project. +// Copyright (c) 2024-2026 Wave Foundation +// Copyright (c) 2024-2026 LunaStev and contributors +// +// This Source Code Form is subject to the terms of the +// Mozilla Public License, v. 2.0. +// If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +import("std::sys::socket"); +import("std::sys::fs"); +import("std::net::address"); + +fun net_fd_valid(fd: i64) -> bool { + if (fd >= 0) { + return true; + } + return false; +} + +fun net_set_reuseaddr(fd: i64) -> i64 { + var one: i32 = 1; + return setsockopt( + fd, + SOL_SOCKET, + SO_REUSEADDR, + &one as ptr, + 4 + ); +} + +fun net_socket_tcp_v4() -> i64 { + return socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); +} + +fun net_socket_udp_v4() -> i64 { + return socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); +} + +fun net_bind_v4(fd: i64, addr: NetAddrV4) -> i64 { + var sa: NetSockAddrIn = net_to_sockaddr_v4(addr); + return bind(fd, &sa, 16); +} + +fun net_connect_v4(fd: i64, addr: NetAddrV4) -> i64 { + var sa: NetSockAddrIn = net_to_sockaddr_v4(addr); + return connect(fd, &sa, 16); +} + +fun net_accept_v4(fd: i64, out_addr: ptr) -> i64 { + var sa: NetSockAddrIn; + var salen: i32 = 16; + + var cfd: i64 = accept(fd, &sa, &salen); + if (cfd >= 0 && out_addr != null) { + deref out_addr = net_from_sockaddr_v4(sa); + } + + return cfd; +} + +fun net_send_all(fd: i64, buf: ptr, len: i64, flags: i32) -> i64 { + if (len <= 0) { + return 0; + } + + var sent_total: i64 = 0; + while (sent_total < len) { + var n: i64 = send(fd, buf + sent_total, len - sent_total, flags); + if (n <= 0) { + if (sent_total == 0) { + return n; + } + return sent_total; + } + sent_total += n; + } + return sent_total; +} + +fun net_recv_exact(fd: i64, buf: ptr, len: i64, flags: i32) -> i64 { + if (len <= 0) { + return 0; + } + + var read_total: i64 = 0; + while (read_total < len) { + var n: i64 = recv(fd, buf + read_total, len - read_total, flags); + if (n <= 0) { + if (read_total == 0) { + return n; + } + return read_total; + } + read_total += n; + } + return read_total; +} + +fun net_sendto_v4( + fd: i64, + addr: NetAddrV4, + buf: ptr, + len: i64, + flags: i32 +) -> i64 { + var sa: NetSockAddrIn = net_to_sockaddr_v4(addr); + return sendto(fd, buf, len, flags, &sa, 16); +} + +fun net_recvfrom_v4( + fd: i64, + buf: ptr, + len: i64, + flags: i32, + out_addr: ptr +) -> i64 { + var sa: NetSockAddrIn; + var salen: i32 = 16; + + var n: i64 = recvfrom(fd, buf, len, flags, &sa, &salen); + if (n >= 0 && out_addr != null) { + deref out_addr = net_from_sockaddr_v4(sa); + } + return n; +} + +fun net_shutdown_close(fd: i64) { + shutdown(fd, SHUT_RDWR); + close(fd); +} diff --git a/std/net/tcp.wave b/std/net/tcp.wave index 60e61d24..30f6ac4c 100644 --- a/std/net/tcp.wave +++ b/std/net/tcp.wave @@ -9,43 +9,15 @@ // // SPDX-License-Identifier: MPL-2.0 -// ======================================================= -// TCP networking for Wave -// ======================================================= -// -// Stream-based TCP API built on top of -// std::sys::socket -// -// Blocking, minimal, synchronous TCP abstraction. -// ======================================================= - import("std::sys::socket"); -import("std::sys::fs"); - - -// ----------------------- -// IPv4 address -// ----------------------- +import("std::net::address"); +import("std::net::socket_base"); struct TcpAddr { - ip: i32; // network byte order - port: i16; // network byte order -} - - -// sockaddr_in (internal) -struct SockAddrIn { - family: i16; // AF_INET - port: i16; - addr: i32; - zero: array; + ip: i32; // network byte order + port: i16; // network byte order } - -// ----------------------- -// TCP types -// ----------------------- - struct TcpListener { fd: i64; } @@ -54,84 +26,105 @@ struct TcpStream { fd: i64; } +fun _str_len(s: str) -> i64 { + var n: i64 = 0; + while (s[n] != 0) { + n += 1; + } + return n; +} + +fun _tcp_to_net_addr(addr: TcpAddr) -> NetAddrV4 { + var value: NetAddrV4; + value.ip = addr.ip; + value.port = addr.port; + return value; +} -// ----------------------- -// helpers -// ----------------------- +fun _tcp_from_net_addr(addr: NetAddrV4) -> TcpAddr { + var value: TcpAddr; + value.ip = addr.ip; + value.port = addr.port; + return value; +} fun htons(x: i16) -> i16 { - return ((x & 255) << 8) | ((x >> 8) & 255); + return net_htons(x); } fun htonl(x: i32) -> i32 { - return ((x & 0x000000FF) << 24) - | ((x & 0x0000FF00) << 8) - | ((x & 0x00FF0000) >> 8) - | ((x & 0xFF000000) >> 24); + return net_htonl(x); +} + +fun tcp_addr(ip_host_order: i32, port_host_order: i16) -> TcpAddr { + return _tcp_from_net_addr(net_addr_v4(ip_host_order, port_host_order)); } -fun _to_sockaddr(addr: TcpAddr) -> SockAddrIn { - return SockAddrIn { - family: AF_INET as i16, - port: addr.port, - addr: addr.ip, - zero: [0,0,0,0,0,0,0,0] - }; +fun tcp_addr_any(port_host_order: i16) -> TcpAddr { + return _tcp_from_net_addr(net_addr_any_v4(port_host_order)); } +fun tcp_addr_loopback(port_host_order: i16) -> TcpAddr { + return _tcp_from_net_addr(net_addr_loopback_v4(port_host_order)); +} -// ----------------------- -// listener -// ----------------------- +fun tcp_set_reuseaddr(fd: i64) -> i64 { + return net_set_reuseaddr(fd); +} fun tcp_bind(port: i16) -> TcpListener { - let fd: i64 = socket( - AF_INET, - SOCK_STREAM, - IPPROTO_TCP - ); - - let addr: SockAddrIn = SockAddrIn { - family: AF_INET as i16, - port: htons(port), - addr: 0, - zero: [0,0,0,0,0,0,0,0] - }; - - bind(fd, &addr, 16); - listen(fd, 128); + return tcp_bind_with_backlog(port, 128); +} + +fun tcp_bind_with_backlog(port: i16, backlog: i32) -> TcpListener { + var addr: TcpAddr = tcp_addr_any(port); + return tcp_bind_addr(addr, backlog); +} + +fun tcp_bind_addr(addr: TcpAddr, backlog: i32) -> TcpListener { + var fd: i64 = net_socket_tcp_v4(); + net_set_reuseaddr(fd); + + net_bind_v4(fd, _tcp_to_net_addr(addr)); + listen(fd, backlog); return TcpListener { fd: fd }; } fun tcp_accept(listener: TcpListener) -> TcpStream { - let fd: i64 = accept(listener.fd, null, null); - return TcpStream { fd: fd }; + var cfd: i64 = net_accept_v4(listener.fd, null); + return TcpStream { fd: cfd }; } -fun tcp_close_listener(listener: TcpListener) { - shutdown(listener.fd, SHUT_RDWR); - close(listener.fd); +fun tcp_accept_addr(listener: TcpListener, src: ptr) -> TcpStream { + var peer: NetAddrV4; + var cfd: i64 = net_accept_v4(listener.fd, &peer); + if (cfd >= 0) { + deref src = _tcp_from_net_addr(peer); + } + return TcpStream { fd: cfd }; } - -// ----------------------- -// stream (server + client common) -// ----------------------- +fun tcp_close_listener(listener: TcpListener) { + net_shutdown_close(listener.fd); +} fun tcp_connect(addr: TcpAddr) -> TcpStream { - let fd: i64 = socket( - AF_INET, - SOCK_STREAM, - IPPROTO_TCP - ); - - let sa: SockAddrIn = _to_sockaddr(addr); - connect(fd, &sa, 16); - + var fd: i64 = net_socket_tcp_v4(); + net_connect_v4(fd, _tcp_to_net_addr(addr)); return TcpStream { fd: fd }; } +fun tcp_try_connect(addr: TcpAddr) -> i64 { + var fd: i64 = net_socket_tcp_v4(); + var r: i64 = net_connect_v4(fd, _tcp_to_net_addr(addr)); + if (r < 0) { + close(fd); + return r; + } + return fd; +} + fun tcp_read(stream: TcpStream, buf: ptr, len: i64) -> i64 { return recv(stream.fd, buf, len, 0); } @@ -140,7 +133,38 @@ fun tcp_write(stream: TcpStream, buf: ptr, len: i64) -> i64 { return send(stream.fd, buf, len, 0); } +fun tcp_write_all(stream: TcpStream, buf: ptr, len: i64) -> i64 { + return net_send_all(stream.fd, buf, len, 0); +} + +fun tcp_read_exact(stream: TcpStream, buf: ptr, len: i64) -> i64 { + return net_recv_exact(stream.fd, buf, len, 0); +} + +fun tcp_write_str(stream: TcpStream, s: str) -> i64 { + var n: i64 = _str_len(s); + if (n <= 0) { + return 0; + } + return tcp_write_all(stream, s as ptr, n); +} + +fun tcp_from_fd(fd: i64) -> TcpStream { + return TcpStream { fd: fd }; +} + +fun tcp_listener_from_fd(fd: i64) -> TcpListener { + return TcpListener { fd: fd }; +} + +fun tcp_stream_valid(stream: TcpStream) -> bool { + return net_fd_valid(stream.fd); +} + +fun tcp_listener_valid(listener: TcpListener) -> bool { + return net_fd_valid(listener.fd); +} + fun tcp_close(stream: TcpStream) { - shutdown(stream.fd, SHUT_RDWR); - close(stream.fd); + net_shutdown_close(stream.fd); } diff --git a/std/net/udp.wave b/std/net/udp.wave index eb01b1a2..25eade45 100644 --- a/std/net/udp.wave +++ b/std/net/udp.wave @@ -9,123 +9,99 @@ // // SPDX-License-Identifier: MPL-2.0 -// ======================================================= -// UDP networking for Wave -// ======================================================= -// -// Datagram-based UDP API built on top of -// std::sys::socket -// -// No connection abstraction. -// No async. -// No buffering. -// ======================================================= - import("std::sys::socket"); -import("std::sys::fs"); - - -// ----------------------- -// IPv4 address -// ----------------------- +import("std::net::address"); +import("std::net::socket_base"); struct UdpAddr { - ip: i32; // network byte order - port: i16; // network byte order + ip: i32; // network byte order + port: i16; // network byte order } - -// sockaddr_in (internal use) -struct SockAddrIn { - family: i16; // AF_INET - port: i16; - addr: i32; - zero: array; -} - - -// ----------------------- -// UDP socket -// ----------------------- - struct UdpSocket { fd: i64; } +fun _str_len(s: str) -> i64 { + var n: i64 = 0; + while (s[n] != 0) { + n += 1; + } + return n; +} + +fun _udp_to_net_addr(addr: UdpAddr) -> NetAddrV4 { + var value: NetAddrV4; + value.ip = addr.ip; + value.port = addr.port; + return value; +} -// ----------------------- -// helpers -// ----------------------- +fun _udp_from_net_addr(addr: NetAddrV4) -> UdpAddr { + var value: UdpAddr; + value.ip = addr.ip; + value.port = addr.port; + return value; +} fun htons(x: i16) -> i16 { - return ((x & 255) << 8) | ((x >> 8) & 255); + return net_htons(x); } fun htonl(x: i32) -> i32 { - return ((x & 0x000000FF) << 24) - | ((x & 0x0000FF00) << 8) - | ((x & 0x00FF0000) >> 8) - | ((x & 0xFF000000) >> 24); + return net_htonl(x); } -fun _to_sockaddr(addr: UdpAddr) -> SockAddrIn { - return SockAddrIn { - family: AF_INET as i16, - port: addr.port, - addr: addr.ip, - zero: [0,0,0,0,0,0,0,0] - }; +fun udp_addr(ip_host_order: i32, port_host_order: i16) -> UdpAddr { + return _udp_from_net_addr(net_addr_v4(ip_host_order, port_host_order)); } +fun udp_addr_any(port_host_order: i16) -> UdpAddr { + return _udp_from_net_addr(net_addr_any_v4(port_host_order)); +} -// ----------------------- -// socket lifecycle -// ----------------------- - -fun udp_bind(port: i16) -> UdpSocket { - let fd: i64 = socket( - AF_INET, - SOCK_DGRAM, - IPPROTO_UDP - ); +fun udp_addr_loopback(port_host_order: i16) -> UdpAddr { + return _udp_from_net_addr(net_addr_loopback_v4(port_host_order)); +} - let addr: SockAddrIn = SockAddrIn { - family: AF_INET as i16, - port: htons(port), - addr: 0, - zero: [0,0,0,0,0,0,0,0] - }; +fun udp_set_reuseaddr(fd: i64) -> i64 { + return net_set_reuseaddr(fd); +} - bind(fd, &addr, 16); +fun udp_bind(port: i16) -> UdpSocket { + return udp_bind_addr(udp_addr_any(port)); +} +fun udp_bind_addr(addr: UdpAddr) -> UdpSocket { + var fd: i64 = net_socket_udp_v4(); + net_set_reuseaddr(fd); + net_bind_v4(fd, _udp_to_net_addr(addr)); return UdpSocket { fd: fd }; } fun udp_close(sock: UdpSocket) { - shutdown(sock.fd, SHUT_RDWR); - close(sock.fd); + net_shutdown_close(sock.fd); } - -// ----------------------- -// send / recv -// ----------------------- - fun udp_send_to( sock: UdpSocket, addr: UdpAddr, buf: ptr, len: i64 ) -> i64 { - let sa: SockAddrIn = _to_sockaddr(addr); - return sendto( - sock.fd, - buf, - len, - 0, - &sa, - 16 - ); + return net_sendto_v4(sock.fd, _udp_to_net_addr(addr), buf, len, 0); +} + +fun udp_send_str_to(sock: UdpSocket, addr: UdpAddr, s: str) -> i64 { + var n: i64 = _str_len(s); + if (n <= 0) { + return 0; + } + return udp_send_to(sock, addr, s as ptr, n); +} + +fun udp_recv(sock: UdpSocket, buf: ptr, len: i64) -> i64 { + return recv(sock.fd, buf, len, 0); } fun udp_recv_from( @@ -134,22 +110,18 @@ fun udp_recv_from( len: i64, src: ptr ) -> i64 { - var sa: SockAddrIn; - var salen: i32 = 16; - - let n: i64 = recvfrom( - sock.fd, - buf, - len, - 0, - &sa, - &salen - ); - - deref src = UdpAddr { - ip: sa.addr, - port: sa.port - }; - + var peer: NetAddrV4; + var n: i64 = net_recvfrom_v4(sock.fd, buf, len, 0, &peer); + if (n >= 0) { + deref src = _udp_from_net_addr(peer); + } return n; } + +fun udp_from_fd(fd: i64) -> UdpSocket { + return UdpSocket { fd: fd }; +} + +fun udp_socket_valid(sock: UdpSocket) -> bool { + return net_fd_valid(sock.fd); +} diff --git a/test/test91.wave b/test/test91.wave new file mode 100644 index 00000000..dbdbacee --- /dev/null +++ b/test/test91.wave @@ -0,0 +1,77 @@ +#[target(os="linux")] +fun platform_id() -> i32 { + return 1; +} + +#[target(os="macos")] +fun platform_id() -> i32 { + return 2; +} + +struct Box { + value: T; +} + +struct Pair { + first: A; + second: B; +} + +fun identity(x: T) -> T { + return x; +} + +fun wrap(x: T) -> Box { + var boxed: Box; + boxed.value = x; + return boxed; +} + +fun make_pair(a: A, b: B) -> Pair { + var pair_value: Pair; + pair_value.first = a; + pair_value.second = b; + return pair_value; +} + +fun first_of(p: Pair) -> A { + return p.first; +} + +fun main() -> i32 { + var a: i32 = identity(10); + if (a != 10) { + return 1; + } + + var boxed_i64: Box = wrap(1234); + if (boxed_i64.value != 1234) { + return 2; + } + + var p: Pair = make_pair(7, 99); + if (p.first != 7) { + return 3; + } + + if (p.second != 99) { + return 4; + } + + var nested: Box> = wrap>(wrap(42)); + if (nested.value.value != 42) { + return 5; + } + + var p_first: i32 = first_of(p); + if (p_first != 7) { + return 6; + } + + var os: i32 = platform_id(); + if (os != 1 && os != 2) { + return 7; + } + + return 0; +}