diff options
-rw-r--r-- | main.c | 3 | ||||
-rw-r--r-- | test.toc | 1 | ||||
-rw-r--r-- | types.c | 35 |
3 files changed, 28 insertions, 11 deletions
@@ -1,8 +1,7 @@ /* TODO: -named return values optional params -named params +named args evaluator (simplify compile time constant expressions) re-do cgen */ @@ -1,4 +1,5 @@ main @= fn() { + foo @= fn() i64 { return 3; }; test @= fn(x : i64, y : i32, z,w: i64) ret1 : i64, ret2 : i64 { ret1 = x; }; @@ -1,7 +1,8 @@ typedef struct { Array in_decls; /* array of declarations we are currently inside */ Block *block; - Type *ret_type; /* the return type of the function we're currently parsing. NULL for none. */ + bool can_ret; + Type ret_type; /* the return type of the function we're currently parsing. */ } Typer; static bool types_stmt(Typer *tr, Statement *s); @@ -308,8 +309,6 @@ static bool type_can_be_truthy(Type *t) { static bool types_expr(Typer *tr, Expression *e) { if (e->flags & EXPR_FLAG_FOUND_TYPE) return true; - Type *prev_ret_type = tr->ret_type; - Type *t = &e->type; t->flags = 0; t->kind = TYPE_UNKNOWN; /* default to unknown type (in the case of an error) */ @@ -317,12 +316,22 @@ static bool types_expr(Typer *tr, Expression *e) { bool success = true; switch (e->kind) { case EXPR_FN: { + Type prev_ret_type = tr->ret_type; + bool prev_can_ret = tr->can_ret; FnExpr *f = &e->fn; if (!type_of_fn(tr, f, t)) { success = false; goto fn_ret; } - tr->ret_type = t->fn.types.data; + bool has_named_ret_vals = e->fn.ret_decls.data != NULL; + if (has_named_ret_vals) { + /* set return type to void to not allow return values */ + tr->ret_type.kind = TYPE_VOID; + tr->ret_type.flags = 0; + } else { + tr->ret_type = *(Type *)t->fn.types.data; + } + tr->can_ret = true; arr_foreach(&f->params, Declaration, decl) add_ident_decls(&f->body, decl); arr_foreach(&f->ret_decls, Declaration, decl) @@ -354,7 +363,7 @@ static bool types_expr(Typer *tr, Expression *e) { success = false; goto fn_ret; } - } else if (ret_type->kind != TYPE_VOID) { + } else if (ret_type->kind != TYPE_VOID && !has_named_ret_vals) { Array stmts = e->fn.body.stmts; if (stmts.len) { Statement *last_stmt = (Statement *)stmts.data + (stmts.len - 1); @@ -376,6 +385,7 @@ static bool types_expr(Typer *tr, Expression *e) { } fn_ret: tr->ret_type = prev_ret_type; + tr->can_ret = prev_can_ret; if (!success) return false; } break; case EXPR_LITERAL_INT: @@ -779,19 +789,25 @@ static bool types_stmt(Typer *tr, Statement *s) { return false; break; case STMT_RET: - if (!tr->ret_type) { + if (!tr->can_ret) { err_print(s->where, "return outside of a function."); return false; } if (s->ret.flags & RET_FLAG_EXPR) { - if (tr->ret_type->kind == TYPE_VOID) { - err_print(s->where, "Return value in void function."); + if (tr->ret_type.kind == TYPE_VOID) { + err_print(s->where, "Return value in function which should not return a value."); return false; } if (!types_expr(tr, &s->ret.expr)) return false; + if (!type_eq(&tr->ret_type, &s->ret.expr.type)) { + char *got = type_to_str(&s->ret.expr.type); + char *expected = type_to_str(&tr->ret_type); + err_print(s->where, "Returning type %s in function which returns %s.", got, expected); + return false; + } } else { - if (tr->ret_type->kind != TYPE_VOID) { + if (tr->ret_type.kind != TYPE_VOID) { err_print(s->where, "No return value in non-void function."); return false; } @@ -803,6 +819,7 @@ static bool types_stmt(Typer *tr, Statement *s) { static void typer_create(Typer *tr) { tr->block = NULL; + tr->can_ret = false; arr_create(&tr->in_decls, sizeof(Declaration *)); } |