summaryrefslogtreecommitdiff
path: root/parse.c
diff options
context:
space:
mode:
Diffstat (limited to 'parse.c')
-rw-r--r--parse.c43
1 files changed, 41 insertions, 2 deletions
diff --git a/parse.c b/parse.c
index 63a8637..84066ad 100644
--- a/parse.c
+++ b/parse.c
@@ -162,13 +162,20 @@ typedef struct FnExpr {
Declaration params; /* declaration of the parameters to this function */
Type ret_type;
Block body;
-} FnExpr; /* an expression such as fn(x: int) int {return 2 * x;} */
+} FnExpr; /* an expression such as fn(x: int) int { 2 * x } */
typedef enum {
STMT_DECL,
- STMT_EXPR
+ STMT_EXPR,
+ STMT_RET
} StatementKind;
+#define RET_FLAG_EXPR 0x01
+typedef struct {
+ uint16_t flags;
+ Expression expr;
+} Return;
+
#define STMT_FLAG_VOIDED_EXPR 0x01 /* the "4;" in fn () { 4; } is a voided expression, but the "4" in fn () int { 4 } is not */
typedef struct {
Location where;
@@ -177,6 +184,7 @@ typedef struct {
union {
Declaration decl;
Expression expr;
+ Return ret;
};
} Statement;
@@ -1311,6 +1319,30 @@ static bool parse_stmt(Parser *p, Statement *s) {
/*
TODO: statements such as 3, 5; will not work.
*/
+ if (token_is_kw(t->token, KW_RETURN)) {
+ s->kind = STMT_RET;
+ t->token++;
+ s->ret.flags = 0;
+ if (token_is_kw(t->token, KW_SEMICOLON)) {
+ /* return with no expr */
+ t->token++;
+ return true;
+ }
+ s->ret.flags |= RET_FLAG_EXPR;
+ Token *end = expr_find_end(p, 0, NULL);
+ if (!end) {
+ while (t->token->kind != TOKEN_EOF) t->token++; /* move to end of file */
+ return false;
+ }
+ if (!token_is_kw(end, KW_SEMICOLON)) {
+ err_print(end->where, "Expected ';' at end of return statement.");
+ t->token = end->kind == TOKEN_EOF ? end : end + 1;
+ return false;
+ }
+ bool success = parse_expr(p, &s->ret.expr, end);
+ t->token = end + 1;
+ return success;
+ }
if (token_is_kw(t->token + 1, KW_COLON) || token_is_kw(t->token + 1, KW_COMMA)
|| token_is_kw(t->token + 1, KW_AT)) {
s->kind = STMT_DECL;
@@ -1498,6 +1530,7 @@ static void fprint_stmt(FILE *out, Statement *s) {
fprintf(out, "(void)");
switch (s->kind) {
+
case STMT_DECL:
fprint_decl(out, &s->decl);
fprintf(out, ";\n");
@@ -1506,6 +1539,12 @@ static void fprint_stmt(FILE *out, Statement *s) {
fprint_expr(out, &s->expr);
fprintf(out, ";\n");
break;
+ case STMT_RET:
+ fprintf(out, "return ");
+ if (s->ret.flags & RET_FLAG_EXPR)
+ fprint_expr(out, &s->ret.expr);
+ fprintf(out, ";\n");
+ break;
}
}