summaryrefslogtreecommitdiff
path: root/types.c
diff options
context:
space:
mode:
Diffstat (limited to 'types.c')
-rw-r--r--types.c35
1 files changed, 26 insertions, 9 deletions
diff --git a/types.c b/types.c
index 7ca8cf6..9fcd5f1 100644
--- a/types.c
+++ b/types.c
@@ -1,7 +1,8 @@
typedef struct {
Array in_decls; /* array of declarations we are currently inside */
Block *block;
- Type *ret_type; /* the return type of the function we're currently parsing. NULL for none. */
+ bool can_ret;
+ Type ret_type; /* the return type of the function we're currently parsing. */
} Typer;
static bool types_stmt(Typer *tr, Statement *s);
@@ -308,8 +309,6 @@ static bool type_can_be_truthy(Type *t) {
static bool types_expr(Typer *tr, Expression *e) {
if (e->flags & EXPR_FLAG_FOUND_TYPE) return true;
- Type *prev_ret_type = tr->ret_type;
-
Type *t = &e->type;
t->flags = 0;
t->kind = TYPE_UNKNOWN; /* default to unknown type (in the case of an error) */
@@ -317,12 +316,22 @@ static bool types_expr(Typer *tr, Expression *e) {
bool success = true;
switch (e->kind) {
case EXPR_FN: {
+ Type prev_ret_type = tr->ret_type;
+ bool prev_can_ret = tr->can_ret;
FnExpr *f = &e->fn;
if (!type_of_fn(tr, f, t)) {
success = false;
goto fn_ret;
}
- tr->ret_type = t->fn.types.data;
+ bool has_named_ret_vals = e->fn.ret_decls.data != NULL;
+ if (has_named_ret_vals) {
+ /* set return type to void to not allow return values */
+ tr->ret_type.kind = TYPE_VOID;
+ tr->ret_type.flags = 0;
+ } else {
+ tr->ret_type = *(Type *)t->fn.types.data;
+ }
+ tr->can_ret = true;
arr_foreach(&f->params, Declaration, decl)
add_ident_decls(&f->body, decl);
arr_foreach(&f->ret_decls, Declaration, decl)
@@ -354,7 +363,7 @@ static bool types_expr(Typer *tr, Expression *e) {
success = false;
goto fn_ret;
}
- } else if (ret_type->kind != TYPE_VOID) {
+ } else if (ret_type->kind != TYPE_VOID && !has_named_ret_vals) {
Array stmts = e->fn.body.stmts;
if (stmts.len) {
Statement *last_stmt = (Statement *)stmts.data + (stmts.len - 1);
@@ -376,6 +385,7 @@ static bool types_expr(Typer *tr, Expression *e) {
}
fn_ret:
tr->ret_type = prev_ret_type;
+ tr->can_ret = prev_can_ret;
if (!success) return false;
} break;
case EXPR_LITERAL_INT:
@@ -779,19 +789,25 @@ static bool types_stmt(Typer *tr, Statement *s) {
return false;
break;
case STMT_RET:
- if (!tr->ret_type) {
+ if (!tr->can_ret) {
err_print(s->where, "return outside of a function.");
return false;
}
if (s->ret.flags & RET_FLAG_EXPR) {
- if (tr->ret_type->kind == TYPE_VOID) {
- err_print(s->where, "Return value in void function.");
+ if (tr->ret_type.kind == TYPE_VOID) {
+ err_print(s->where, "Return value in function which should not return a value.");
return false;
}
if (!types_expr(tr, &s->ret.expr))
return false;
+ if (!type_eq(&tr->ret_type, &s->ret.expr.type)) {
+ char *got = type_to_str(&s->ret.expr.type);
+ char *expected = type_to_str(&tr->ret_type);
+ err_print(s->where, "Returning type %s in function which returns %s.", got, expected);
+ return false;
+ }
} else {
- if (tr->ret_type->kind != TYPE_VOID) {
+ if (tr->ret_type.kind != TYPE_VOID) {
err_print(s->where, "No return value in non-void function.");
return false;
}
@@ -803,6 +819,7 @@ static bool types_stmt(Typer *tr, Statement *s) {
static void typer_create(Typer *tr) {
tr->block = NULL;
+ tr->can_ret = false;
arr_create(&tr->in_decls, sizeof(Declaration *));
}