diff options
author | Leo Tenenbaum <pommicket@gmail.com> | 2019-09-06 22:17:38 -0400 |
---|---|---|
committer | Leo Tenenbaum <pommicket@gmail.com> | 2019-09-06 22:17:38 -0400 |
commit | f146ede613f0095a12b2fd0f756bae63b167abe2 (patch) | |
tree | 5ae5901af95318e6d016001d3d7384f61d625d94 | |
parent | 7097749bf41739feffbb7f6da20b2b951f73aa9d (diff) |
started block return values
-rw-r--r-- | base_cgen.c | 11 | ||||
-rwxr-xr-x | build.sh | 2 | ||||
-rw-r--r-- | cgen.c | 62 | ||||
-rw-r--r-- | eval.c | 4 | ||||
-rw-r--r-- | identifiers.c | 5 | ||||
-rw-r--r-- | out.c | 11 | ||||
-rw-r--r-- | out.h | 2 | ||||
-rw-r--r-- | parse.c | 223 | ||||
-rw-r--r-- | test.toc | 9 | ||||
-rw-r--r-- | types.c | 54 |
10 files changed, 226 insertions, 157 deletions
diff --git a/base_cgen.c b/base_cgen.c index 4d4bb3a..b8da495 100644 --- a/base_cgen.c +++ b/base_cgen.c @@ -8,6 +8,7 @@ typedef struct { FILE *c_out; FILE *h_out; unsigned long anon_fn_count; + unsigned long anon_var_count; Block *block; int indent_level; bool indent_next; /* should the next thing written be indented? */ @@ -123,7 +124,7 @@ static const char *builtin_type_to_str(BuiltinType b) { } /* will this function use a pointer parameter for output? (e.g. fn()[3]int => void(int (*x)[3]) */ -static bool fn_uses_out_param(Type *fn_ret_type) { +static bool cgen_fn_uses_out_param(Type *fn_ret_type) { switch (fn_ret_type->kind) { case TYPE_TUPLE: case TYPE_ARR: @@ -144,7 +145,7 @@ static void cgen_type_pre(CGenerator *g, Type *t) { case TYPE_FN: { Type *types = t->fn.types.data; Type *ret_type = &types[0]; - if (fn_uses_out_param(ret_type)) { + if (cgen_fn_uses_out_param(ret_type)) { cgen_write(g, "void "); } else { cgen_type_pre(g, ret_type); @@ -166,7 +167,7 @@ static void cgen_type_pre(CGenerator *g, Type *t) { static void cgen_type_post(CGenerator *g, Type *t); /* either pass NULL for param_types (x)or for params */ static void cgen_fn_params(CGenerator *g, Type *param_types, Param *params, size_t nparams, Type *ret_type) { - bool uses_out_param = fn_uses_out_param(ret_type); + bool uses_out_param = cgen_fn_uses_out_param(ret_type); cgen_write(g, "("); if (nparams) { @@ -214,7 +215,7 @@ static void cgen_type_post(CGenerator *g, Type *t) { Type *param_types = types + 1; assert(t->fn.types.len > 0); size_t nparams = t->fn.types.len-1; - bool uses_out_param = fn_uses_out_param(ret_type); + bool uses_out_param = cgen_fn_uses_out_param(ret_type); cgen_write(g, ")"); cgen_fn_params(g, param_types, NULL, nparams, ret_type); if (!uses_out_param) { @@ -256,7 +257,7 @@ static bool cgen_fn_header(CGenerator *g, FnExpr *f) { cgen_write(g, "static "); /* anonymous functions only exist in this translation unit */ } - bool uses_out_param = fn_uses_out_param(&f->ret_type); + bool uses_out_param = cgen_fn_uses_out_param(&f->ret_type); size_t nparams = f->params.len; if (uses_out_param) { cgen_write(g, "void "); @@ -1,5 +1,5 @@ #!/bin/bash -CC=gcc +CC=clang # Possible extra build flags # these are for compiling the compiler, and NOT for compiling the program itself. @@ -2,6 +2,7 @@ static void cgen_create(CGenerator *g, Identifiers *ids, FILE *c_out, FILE *h_ou g->c_out = c_out; g->h_out = h_out; g->anon_fn_count = 0; + g->anon_var_count = 0; g->indent_level = 0; g->block = NULL; g->indent_next = true; @@ -220,25 +221,60 @@ static bool cgen_stmt(CGenerator *g, Statement *s) { static bool cgen_fns_in_stmt(CGenerator *g, Statement *s); +typedef struct { + bool is_return; /* true => this is a function return */ + unsigned long var_no; /* if is_return = false, set the anonymous variable with this number to the return value. */ + const char *exit_with; /* how to exit this block in C, e.g. "break" (not needed if is_return = true). */ +} BlockExitKind; + +/* generates a block but not the functions, etc. inside it */ +static bool cgen_block(CGenerator *g, Block *b, BlockExitKind *exit_kind) { + bool success = true; + cgen_writeln(g, "{"); + g->indent_level++; + arr_foreach(&b->stmts, Statement, s) { + if (!cgen_stmt(g, s)) + success = false; + } + if (exit_kind->is_return) { + /* generate return from function */ + if (b->ret_expr && cgen_fn_uses_out_param(&b->ret_expr->type)) { + cgen_write(g, "*out__ = "); + cgen_expr(g, b->ret_expr); + cgen_writeln(g, ";"); + cgen_writeln(g, "return;"); + } else { + cgen_write(g, "return"); + if (b->ret_expr) { + cgen_write(g, " "); + cgen_expr(g, b->ret_expr); + } + cgen_writeln(g, ";"); + } + } else { + err_print(b->ret_expr->where, "TODO"); + return false; + } + if (success) { + g->indent_level--; + cgen_writeln(g, "}"); + } + return success; +} + /* Generates function definition, and the definitions of all functions inside this */ static bool cgen_fn(CGenerator *g, FnExpr *f) { if (!cgen_fn_header(g, f)) return false; Block *prev_block = g->block; cgen_block_enter(g, &f->body); - bool ret = true; cgen_write_space(g); - cgen_writeln(g, "{"); - g->indent_level++; - arr_foreach(&f->body.stmts, Statement, s) { - if (!cgen_stmt(g, s)) - ret = false; - } - g->indent_level--; - cgen_writeln(g, "}"); - if (ret) { - arr_foreach(&f->body.stmts, Statement, stmt) { - if (!cgen_fns_in_stmt(g, stmt)) ret = false; - } + BlockExitKind e_kind; + e_kind.is_return = 1; + if (!cgen_block(g, &f->body, &e_kind)) return false; + + bool ret = true; + arr_foreach(&f->body.stmts, Statement, stmt) { + if (!cgen_fns_in_stmt(g, stmt)) ret = false; } cgen_block_exit(g, prev_block); return ret; @@ -80,10 +80,10 @@ static bool eval_expr_as_int(Expression *e, Integer *i) { return false; } if (d->type.kind != TYPE_BUILTIN || !type_builtin_is_integer(d->type.builtin)) { - char type_str[128]; - type_to_str(&d->type, type_str, sizeof type_str); + char *type_str = type_to_str(&d->type); err_print(e->where, "Expected integer, but identifier has type %s.", type_str); info_print(d->where, "Declaration was here."); + free(type_str); return false; } /* TODO: tuples */ diff --git a/identifiers.c b/identifiers.c index 1cb543d..eb82c7b 100644 --- a/identifiers.c +++ b/identifiers.c @@ -62,6 +62,8 @@ static Identifier ident_new(Identifiers *ids, Identifier parent, unsigned char i tree->children[i] = NULL; #endif tree->parent = parent; + if (parent) + tree->depth = parent->depth + 1; tree->index_in_parent = index_in_parent; return tree; } @@ -163,8 +165,7 @@ static char *ident_to_str(Identifier i) { char *str = malloc(i_len + 1); str += i_len; *str = 0; - - while (i) { + while (i->parent) { str--; unsigned char c_high = i->index_in_parent; unsigned char c_low = i->parent->index_in_parent; @@ -2,12 +2,15 @@ /* toc */ #include <stdio.h> -void foo__bar(void) { - puts("Hello!"); -; +void foo(int64_t (*out__)[3]) { + int64_t x[3] = {0}; + *out__ = x; + return; } void main__(void) { - foo__bar(); + int64_t x[3] = foo(); + printf("Foo: %ld\n", (long)x); + return; } int main(void) { @@ -1,4 +1,4 @@ #include <stddef.h> #include <stdint.h> -void foo__bar(void); +void foo(int64_t (*out__)[3]); void main__(void); @@ -52,6 +52,7 @@ typedef struct { typedef struct Block { Array stmts; + struct Expression *ret_expr; /* the return expression of this block, e.g. {foo(); 3} => 3 NULL for no expression. */ } Block; typedef struct { @@ -91,9 +92,6 @@ typedef struct { Array args; /* of Expression */ } DirectExpr; - -#define EXPR_FLAG_FLEXIBLE 0x01 /* e.g. 4 => float/i32/etc. */ - typedef struct Expression { Location where; ExprKind kind; @@ -139,10 +137,12 @@ typedef enum { STMT_DECL, STMT_EXPR } StatementKind; - + +#define STMT_FLAG_VOIDED_EXPR 0x01 /* the "4;" in fn () { 4; } is a voided expression, but the "4" in fn () int { 4 } is not */ typedef struct { Location where; StatementKind kind; + unsigned short flags; union { Declaration decl; Expression expr; @@ -237,7 +237,7 @@ static Keyword builtin_type_to_kw(BuiltinType t) { } /* returns the number of characters written, not including the null character */ -static size_t type_to_str(Type *t, char *buffer, size_t bufsize) { +static size_t type_to_str_(Type *t, char *buffer, size_t bufsize) { switch (t->kind) { case TYPE_VOID: return str_copy(buffer, bufsize, "void"); @@ -256,12 +256,12 @@ static size_t type_to_str(Type *t, char *buffer, size_t bufsize) { for (size_t i = 0; i < nparams; i++) { if (i > 0) written += str_copy(buffer + written, bufsize - written, ", "); - written += type_to_str(¶m_types[i], buffer + written, bufsize - written); + written += type_to_str_(¶m_types[i], buffer + written, bufsize - written); } written += str_copy(buffer + written, bufsize - written, ")"); if (ret_type->kind != TYPE_VOID) { written += str_copy(buffer + written, bufsize - written, " "); - written += type_to_str(ret_type, buffer + written, bufsize - written); + written += type_to_str_(ret_type, buffer + written, bufsize - written); } return written; } break; @@ -274,7 +274,7 @@ static size_t type_to_str(Type *t, char *buffer, size_t bufsize) { written += str_copy(buffer + written, bufsize - written, "N"); } written += str_copy(buffer + written, bufsize - written, "]"); - written += type_to_str(t->arr.of, buffer + written, bufsize - written); + written += type_to_str_(t->arr.of, buffer + written, bufsize - written); return written; } break; case TYPE_TUPLE: { @@ -282,7 +282,7 @@ static size_t type_to_str(Type *t, char *buffer, size_t bufsize) { arr_foreach(&t->tuple, Type, child) { if (child != t->tuple.data) written += str_copy(buffer + written, bufsize - written, ", "); - written += type_to_str(child, buffer + written, bufsize - written); + written += type_to_str_(child, buffer + written, bufsize - written); } written += str_copy(buffer + written, bufsize - written, ")"); return written; @@ -293,6 +293,14 @@ static size_t type_to_str(Type *t, char *buffer, size_t bufsize) { return 0; } +/* return value should be freed by caller */ +static char *type_to_str(Type *t) { + /* TODO allow types >255 chars */ + char *ret = err_malloc(256); + type_to_str_(t, ret, 256); + return ret; +} + /* allocate a new expression. */ @@ -316,99 +324,66 @@ static int op_precedence(Keyword op) { } } +/* TODO: check that we check which thing ends it everywhere */ -/* - ends_with = which keyword does this expression end with? - if it's KW_RPAREN, this will match parentheses properly. -*/ -typedef enum { - EXPR_END_RPAREN_OR_COMMA, - EXPR_END_RSQUARE, - EXPR_END_SEMICOLON -} ExprEndKind; -static Token *expr_find_end(Parser *p, ExprEndKind ends_with) { +#define EXPR_CAN_END_WITH_COMMA 0x01 /* a comma could end the expression */ + +static Token *expr_find_end(Parser *p, unsigned flags) { Tokenizer *t = p->tokr; - int bracket_level = 0; /* if ends_with = EXPR_END_RSQUARE, used for square brackets, - if ends_with = EXPR_END_RPAREN_OR_COMMA, used for parens */ + int paren_level = 0; int brace_level = 0; + int square_level = 0; Token *token = t->token; while (1) { - switch (ends_with) { - case EXPR_END_RPAREN_OR_COMMA: - if (token->kind == TOKEN_KW) { - switch (token->kw) { - case KW_COMMA: - if (bracket_level == 0) - return token; - break; - case KW_LPAREN: - bracket_level++; - break; - case KW_RPAREN: - bracket_level--; - if (bracket_level < 0) - return token; - break; - default: break; - } - } - break; - case EXPR_END_RSQUARE: - if (token->kind == TOKEN_KW) { - switch (token->kw) { - case KW_LSQUARE: - bracket_level++; - break; - case KW_RSQUARE: - bracket_level--; - if (bracket_level < 0) - return token; - break; - default: break; - } - } - break; - case EXPR_END_SEMICOLON: - if (token->kind == TOKEN_KW) { - switch (token->kw) { - case KW_SEMICOLON: - /* ignore semicolons inside braces {} */ - if (brace_level == 0) - return token; - break; - case KW_LBRACE: - brace_level++; - break; - case KW_RBRACE: - brace_level--; - if (brace_level < 0) { - t->token = token; - tokr_err(t, "Closing '}' without matching opening '{'."); - return NULL; - } - break; - default: break; - } + if (token->kind == TOKEN_KW) { + switch (token->kw) { + case KW_COMMA: + if ((flags & EXPR_CAN_END_WITH_COMMA) && + paren_level == 0 && brace_level == 0 && square_level == 0) + return token; + break; + case KW_LPAREN: + paren_level++; + break; + case KW_RPAREN: + paren_level--; + if (paren_level < 0) + return token; + break; + case KW_LSQUARE: + square_level++; + break; + case KW_RSQUARE: + square_level--; + if (square_level < 0) + return token; + break; + case KW_LBRACE: + brace_level++; + break; + case KW_RBRACE: + brace_level--; + if (brace_level < 0) + return token; + break; + case KW_SEMICOLON: + if (brace_level == 0) + return token; + break; + default: break; } - break; } if (token->kind == TOKEN_EOF) { - switch (ends_with) { - case EXPR_END_SEMICOLON: - if (brace_level > 0) { - tokr_err(t, "Opening brace was never closed."); /* FEATURE: Find out where this is */ - return NULL; - } else { - tokr_err(t, "Could not find ';' at end of expression."); - return NULL; - } - case EXPR_END_RPAREN_OR_COMMA: - tokr_err(t, "Opening ( was never closed."); - return NULL; - case EXPR_END_RSQUARE: - tokr_err(t, "Opening [ was never closed."); - return NULL; + if (brace_level > 0) { + tokr_err(t, "Opening brace { was never closed."); /* FEATURE: Find out where this is */ + } else if (paren_level > 0) { + tokr_err(t, "Opening parenthesis ( was never closed."); + } else if (square_level > 0) { + tokr_err(t, "Opening square bracket [ was never closed."); + } else { + tokr_err(t, "Could not find end of expression."); } + return NULL; } token++; } @@ -473,7 +448,7 @@ static bool parse_type(Parser *p, Type *type) { Token *start = t->token; type->kind = TYPE_ARR; t->token++; /* move past [ */ - Token *end = expr_find_end(p, EXPR_END_RSQUARE); + Token *end = expr_find_end(p, 0); type->arr.n_expr = parser_new_expr(p); if (!parse_expr(p, type->arr.n_expr, end)) return false; t->token = end + 1; /* go past ] */ @@ -548,6 +523,7 @@ static bool parse_block(Parser *p, Block *b) { t->token++; /* move past { */ arr_create(&b->stmts, sizeof(Statement)); bool ret = true; + b->ret_expr = NULL; /* default to no return unless overwritten later */ if (!token_is_kw(t->token, KW_RBRACE)) { /* non-empty function body */ while (1) { @@ -555,14 +531,28 @@ static bool parse_block(Parser *p, Block *b) { if (!parse_stmt(p, stmt)) { ret = false; } - if (token_is_kw(t->token, KW_RBRACE)) break; + if (token_is_kw(t->token, KW_RBRACE)) { + if (stmt->kind == STMT_EXPR) { + if (!(stmt->flags & STMT_FLAG_VOIDED_EXPR)) { + b->ret_expr = parser_new_expr(p); + *b->ret_expr = stmt->expr; + arr_remove_last(&b->stmts); /* only keep this expression in the return value */ + } + } + break; + } else if (stmt->kind == STMT_EXPR && !(stmt->flags & STMT_FLAG_VOIDED_EXPR)) { + /* in theory, this should never happen right now */ + err_print(stmt->where, "Non-voided expression is not the last statement in a block (you might want to add a ';' to the end of this statement)."); + return false; + } if (t->token->kind == TOKEN_EOF) { tokr_err(t, "Expected '}' to close function body."); return false; } } + } else { + b->ret_expr = NULL; } - t->token++; /* move past } */ p->block = prev_block; return ret; @@ -627,7 +617,7 @@ static bool parse_args(Parser *p, Array *args) { return false; } Expression *arg = arr_add(args); - if (!parse_expr(p, arg, expr_find_end(p, EXPR_END_RPAREN_OR_COMMA))) { + if (!parse_expr(p, arg, expr_find_end(p, EXPR_CAN_END_WITH_COMMA))) { return false; } if (token_is_kw(t->token, KW_RPAREN)) @@ -859,7 +849,7 @@ static bool parse_expr(Parser *p, Expression *e, Token *end) { if (!parse_expr(p, e->binary.lhs, opening_bracket)) return false; /* parse index */ t->token = opening_bracket + 1; - Token *index_end = expr_find_end(p, EXPR_END_RSQUARE); + Token *index_end = expr_find_end(p, 0); if (!parse_expr(p, e->binary.rhs, index_end)) return false; t->token++; /* move past ] */ @@ -887,6 +877,7 @@ static bool parse_expr(Parser *p, Expression *e, Token *end) { } } tokr_err(t, "Not implemented yet."); + t->token = end + 1; return false; } @@ -942,7 +933,7 @@ static bool parse_expr(Parser *p, Expression *e, Token *end) { case KW_COMMA: op = BINARY_COMMA; break; - default: assert(0); break; + default: assert(0); return false; } e->binary.op = op; e->kind = EXPR_BINARY_OP; @@ -1054,7 +1045,7 @@ static bool parse_single_type_in_decl(Parser *p, Declaration *d) { /* OPTIM: switch t->token->kw ? */ if (token_is_kw(t->token, KW_EQ)) { t->token++; - if (!parse_expr(p, &d->expr, expr_find_end(p, EXPR_END_SEMICOLON))) + if (!parse_expr(p, &d->expr, expr_find_end(p, 0))) return false; d->flags |= DECL_FLAG_HAS_EXPR; if (token_is_kw(t->token, KW_SEMICOLON)) { @@ -1083,12 +1074,12 @@ static bool parse_decl(Parser *p, Declaration *d) { static bool parse_stmt(Parser *p, Statement *s) { Tokenizer *t = p->tokr; + s->flags = 0; if (t->token->kind == TOKEN_EOF) tokr_err(t, "Expected statement."); s->where = t->token->where; /* - NOTE: This may cause problems in the future! Other statements might have comma - as the second token. + TODO: statements such as 3, 5; will not work. */ if (token_is_kw(t->token + 1, KW_COLON) || token_is_kw(t->token + 1, KW_COMMA) || token_is_kw(t->token + 1, KW_AT)) { @@ -1110,22 +1101,33 @@ static bool parse_stmt(Parser *p, Statement *s) { return true; } else { s->kind = STMT_EXPR; - Token *end = expr_find_end(p, EXPR_END_SEMICOLON); + Token *end = expr_find_end(p, 0); + if (token_is_kw(end, KW_SEMICOLON)) { + s->flags |= STMT_FLAG_VOIDED_EXPR; + } if (!end) { tokr_err(t, "No semicolon found at end of statement."); while (t->token->kind != TOKEN_EOF) t->token++; /* move to end of file */ return false; } if (!parse_expr(p, &s->expr, end)) { - t->token = end + 1; return false; } - if (!token_is_kw(t->token, KW_SEMICOLON)) { - tokr_err(t, "Expected ';' at end of statement."); + /* go past end */ + if (end->kind == TOKEN_KW) { + switch (end->kw) { + case KW_SEMICOLON: + t->token = end + 1; + break; + case KW_RBRACE: + t->token = end; /* the } is past the end of the expr */ + break; + default: assert(0); break; + } + } else { t->token = end + 1; - return false; } - t->token++; /* move past ; */ + return true; } } @@ -1293,7 +1295,7 @@ static void fprint_expr(FILE *out, Expression *e) { break; case EXPR_DIRECT: fprintf(out, "#"); - fprintf(out, directives[e->direct.which]); + fprintf(out, "%s", directives[e->direct.which]); fprint_args(out, &e->direct.args); break; } @@ -1321,6 +1323,9 @@ static void fprint_decl(FILE *out, Declaration *d) { static void fprint_stmt(FILE *out, Statement *s) { PARSE_PRINT_LOCATION(s->where); + if (s->flags & STMT_FLAG_VOIDED_EXPR) + fprintf(out, "(void)"); + switch (s->kind) { case STMT_DECL: fprint_decl(out, &s->decl); @@ -1,8 +1,9 @@ #C("#include <stdio.h>\n"); -foo.bar @= fn() { - #C("puts(\"Hello!\");\n"); +foo @= fn() [3]int { + x : [3]int; + x }; - main @= fn() { - foo.bar(); + x := foo(); + #C("printf(\"Foo: %ld\\n\", (long)x)"); }; @@ -1,3 +1,6 @@ +static bool types_stmt(Statement *s); +static bool types_expr(Expression *e); + /* pass NULL for block for global scope */ static bool block_enter(Block *b, Array *stmts) { bool ret = true; @@ -102,10 +105,8 @@ static bool type_eq(Type *a, Type *b) { /* expected must equal got, or an error will be produced */ static bool type_must_eq(Location where, Type *expected, Type *got) { if (!type_eq(expected, got)) { - char str_ex[128]; - char str_got[128]; - type_to_str(expected, str_ex, sizeof str_ex); - type_to_str(got, str_got, sizeof str_got); + char *str_ex = type_to_str(expected); + char *str_got = type_to_str(got); err_print(where, "Type mismatch: expected %s, but got %s.", str_ex, str_got); return false; } @@ -117,8 +118,10 @@ static bool expr_must_lval(Expression *e) { switch (e->kind) { case EXPR_IDENT: { IdentDecl *id_decl = ident_decl(e->ident); - if (!id_decl) + if (!id_decl) { err_print(e->where, "Undeclared identifier."); + return false; + } Declaration *d = id_decl->decl; if (d->flags & DECL_FLAG_CONST) { char *istr = ident_to_str(e->ident); @@ -221,7 +224,7 @@ static bool type_resolve(Type *t) { /* NOTE: this does descend into un/binary ops, etc. but NOT into functions */ static bool type_of_expr(Expression *e, Type *t) { t->flags = 0; - + t->kind = TYPE_UNKNOWN; /* default to unknown type (in the case of an error) */ switch (e->kind) { case EXPR_FN: { FnExpr *f = &e->fn; @@ -259,8 +262,7 @@ static bool type_of_expr(Expression *e, Type *t) { if (!type_of_expr(f, &fn_type)) return false; } if (fn_type.kind != TYPE_FN) { - char type[128]; - type_to_str(&fn_type, type, sizeof type); + char *type = type_to_str(&fn_type); err_print(e->where, "Calling non-function (type %s).", type); return false; } @@ -277,8 +279,7 @@ static bool type_of_expr(Expression *e, Type *t) { switch (e->unary.op) { case UNARY_MINUS: if (of_type->kind != TYPE_BUILTIN || !type_builtin_is_numerical(of_type->builtin)) { - char s[128]; - type_to_str(of_type, s, sizeof s); + char *s = type_to_str(of_type); err_print(e->where, "Cannot apply unary - to non-numerical type %s.", s); return false; } @@ -333,9 +334,9 @@ static bool type_of_expr(Expression *e, Type *t) { } } if (!match) { - char s1[128], s2[128]; - type_to_str(lhs_type, s1, sizeof s1); - type_to_str(rhs_type, s2, sizeof s2); + char *s1, *s2; + s1 = type_to_str(lhs_type); + s2 = type_to_str(rhs_type); const char *op = binary_op_to_str(e->binary.op); err_print(e->where, "Mismatched types to operator %s: %s and %s", op, s1, s2); return false; @@ -382,14 +383,13 @@ static bool type_of_expr(Expression *e, Type *t) { return true; } -static bool types_stmt(Statement *s); - static bool types_block(Block *b) { bool ret = true; if (!block_enter(b, &b->stmts)) return false; arr_foreach(&b->stmts, Statement, s) { if (!types_stmt(s)) ret = false; } + if (b->ret_expr) types_expr(b->ret_expr); if (!block_exit(b, &b->stmts)) return false; return ret; } @@ -399,7 +399,28 @@ static bool types_expr(Expression *e) { if (!type_of_expr(e, t)) return false; switch (e->kind) { case EXPR_FN: - return types_block(&e->fn.body); + if (!types_block(&e->fn.body)) + return false; + assert(e->type.kind == TYPE_FN); + Type *ret_type = e->type.fn.types.data; + Expression *ret_expr = e->fn.body.ret_expr; + if (ret_expr) { + if (!type_eq(ret_type, &ret_expr->type)) { + char *got = type_to_str(&ret_expr->type); + char *expected = type_to_str(ret_type); + err_print(ret_expr->where, "Returning type %s, but function returns type %s.", got, expected); + info_print(e->where, "Function declaration is here."); + free(got); free(expected); + return false; + } + } else if (ret_type->kind != TYPE_VOID) { + /* TODO: this should really be at the closing brace, and not the function declaration */ + char *expected = type_to_str(ret_type); + err_print(e->where, "No return value in function which returns %s.", expected); + free(expected); + return false; + } + break; case EXPR_CALL: { bool ret = true; arr_foreach(&e->call.args, Expression, arg) { @@ -444,6 +465,7 @@ static bool types_decl(Declaration *d) { static bool types_stmt(Statement *s) { switch (s->kind) { case STMT_EXPR: + if (!types_expr(&s->expr)) { return false; } |