From 7bb5ac5863bdb4bc7af04ee18e81657a85aa97f2 Mon Sep 17 00:00:00 2001
From: Leo Tenenbaum <pommicket@gmail.com>
Date: Fri, 13 Mar 2020 15:32:12 -0400
Subject: added where conditions for functions, discovered bug with eval
 returning

---
 copy.c               |  6 +++++-
 eval.c               | 26 +++++++++++++++++---------
 main.c               |  2 +-
 parse.c              | 18 +++++++++++++++---
 test.toc             | 43 ++++++++++++++++++++-----------------------
 tests/test.sh        |  1 +
 tests/where.toc      | 30 ++++++++++++++++++++++++++++++
 tests/where_expected |  1 +
 tokenizer.c          |  4 ++++
 types.c              | 26 ++++++++++++++++++++++++++
 types.h              |  4 +++-
 11 files changed, 123 insertions(+), 38 deletions(-)
 create mode 100644 tests/where.toc
 create mode 100644 tests/where_expected

diff --git a/copy.c b/copy.c
index 6f39dab..4270086 100644
--- a/copy.c
+++ b/copy.c
@@ -227,8 +227,12 @@ static void copy_fn_expr(Copier *c, FnExpr *fout, FnExpr *fin, U8 flags) {
 					copy_decl(c, fout->ret_decls + i, fin->ret_decls + i);
 			}
 			copy_type(c, &fout->ret_type, &fin->ret_type);
+			
+			if (fin->condition) {
+				fout->condition = copy_expr_(c, fin->condition);
+			}
+			c->block = prev;
 			if (copy_body) {
-				c->block = prev;
 				copy_block(c, &fout->body, &fin->body, copy_body ? COPY_BLOCK_DONT_CREATE_IDENTS : 0);
 			}
 		}
diff --git a/eval.c b/eval.c
index f73f600..c2efc5d 100644
--- a/eval.c
+++ b/eval.c
@@ -1653,10 +1653,13 @@ static Status eval_stmt(Evaluator *ev, Statement *stmt) {
 			return false;
 	} break;
 	case STMT_RET: {
-		Value r;
-		if (!eval_expr(ev, &stmt->ret.expr, &r))
-			return false;
-		copy_val(NULL, &ev->ret_val, r, &stmt->ret.expr.type);
+		if (stmt->ret.flags & RET_HAS_EXPR) {
+			Value r;
+			if (!eval_expr(ev, &stmt->ret.expr, &r))
+				return false;
+			copy_val(NULL, &ev->ret_val, r, &stmt->ret.expr.type);
+		}
+		ev->returning = true;
 	} break;
 	case STMT_INCLUDE:
 		arr_foreach(stmt->inc.stmts, Statement, sub)
@@ -1667,13 +1670,14 @@ static Status eval_stmt(Evaluator *ev, Statement *stmt) {
 	return true;
 }
 
-static void eval_exit_stmts(Statement *stmts) {
-	arr_foreach(stmts, Statement, s) {
+static void eval_exit_stmts(Statement *stmts, Statement *last_reached) {
+    for (Statement *s = stmts; s <= last_reached; ++s) {
 		if (s->kind == STMT_DECL && !(s->decl->flags & DECL_IS_CONST)) {
 			Declaration *d = s->decl;
 			decl_remove_val(d);
 		} else if (s->kind == STMT_INCLUDE) {
-			eval_exit_stmts(s->inc.stmts);
+			/* TODO: this doesn't work!!! */
+			eval_exit_stmts(s->inc.stmts, arr_last(s->inc.stmts));
 		}
 	}
 }
@@ -1682,12 +1686,16 @@ static Status eval_block(Evaluator *ev, Block *b, Value *v) {
 	Block *prev = ev->typer->block;
 	ev->typer->block = b;
 	bool success = true;
+	Statement *last_reached = arr_last(b->stmts);
 	arr_foreach(b->stmts, Statement, stmt) {
 		if (!eval_stmt(ev, stmt)) {
 			success = false;
 			goto ret;
 		}
-		if (ev->returning) break;
+		if (ev->returning) {
+			last_reached = stmt;
+			break;
+		}
 	}
 	if (!ev->returning && b->ret_expr) {
 		Value r;
@@ -1704,7 +1712,7 @@ static Status eval_block(Evaluator *ev, Block *b, Value *v) {
 			*v = r;
 		}
 	}
-	eval_exit_stmts(b->stmts);
+	eval_exit_stmts(b->stmts, last_reached);
  ret:
 	ev->typer->block = prev;
 	return success;
diff --git a/main.c b/main.c
index a718200..0b0b883 100644
--- a/main.c
+++ b/main.c
@@ -8,7 +8,7 @@
 
 /* 
 TODO:
-where
+fix eval returning from included stuff (see: TODO: this doesn't work!!!)
 #returns_code (function/struct body is a block, to be evaluated at compile time, which returns the actual statements -- you can use this for implementation of printf)
 	- struct varargs
 break
diff --git a/parse.c b/parse.c
index 583f637..704d867 100644
--- a/parse.c
+++ b/parse.c
@@ -331,6 +331,7 @@ typedef enum {
 			  EXPR_CAN_END_WITH_COLON = 0x04,
 			  EXPR_CAN_END_WITH_DOTDOT = 0x08,
 			  EXPR_CAN_END_WITH_EQ = 0x10,
+			  EXPR_CAN_END_WITH_WHERE = 0x20
 			  /* note that parse_type uses -1 for this */
 } ExprEndFlags;
 
@@ -395,6 +396,10 @@ static Token *expr_find_end(Parser *p, ExprEndFlags flags)  {
 				if (all_levels_0 && (flags & EXPR_CAN_END_WITH_EQ))
 					return token;
 				break;
+			case KW_WHERE:
+				if (all_levels_0 && (flags & EXPR_CAN_END_WITH_WHERE))
+					return token;
+				break;
 			case KW_COLON:
 				if ((flags & EXPR_CAN_END_WITH_COLON) && all_levels_0)
 					return token;
@@ -802,7 +807,7 @@ static bool parser_is_definitely_type(Parser *p, Token **end) {
 							--paren_level;
 							if (paren_level == 0) {
 								++t->token;
-								if (token_is_kw(t->token, KW_LBRACE)) goto end; /* void fn expr */
+								if (token_is_kw(t->token, KW_LBRACE) || token_is_kw(t->token, KW_WHERE)) goto end; /* void fn expr */
 								if (is_decl(t)) /* has return declaration */
 									goto end;
 								
@@ -954,6 +959,7 @@ static Status parse_fn_expr(Parser *p, FnExpr *f) {
 	f->instance_id = 0;
 	f->ret_decls = NULL;
 	f->instances = NULL;
+	f->condition = NULL;
 	/* only called when token is fn */
 	assert(token_is_kw(t->token, KW_FN));
 	++t->token;
@@ -984,7 +990,7 @@ static Status parse_fn_expr(Parser *p, FnExpr *f) {
 	    success = false; goto ret;
 	}
 	
-	if (token_is_kw(t->token, KW_LBRACE)) {
+	if (token_is_kw(t->token, KW_LBRACE) || token_is_kw(t->token, KW_WHERE)) {
 		/* void function */
 		f->ret_type.kind = TYPE_VOID;
 		f->ret_type.flags = 0;
@@ -1011,6 +1017,13 @@ static Status parse_fn_expr(Parser *p, FnExpr *f) {
 			goto ret;
 		}
 	}
+	if (token_is_kw(t->token, KW_WHERE)) {
+		++t->token;
+		f->condition = parser_new_expr(p);
+		if (!parse_expr(p, f->condition, expr_find_end(p, EXPR_CAN_END_WITH_LBRACE))) {
+			return false;
+		}
+	}
 	p->block = prev_block; /* be nice to parse_block */
 	if (!parse_block(p, &f->body, PARSE_BLOCK_DONT_CREATE_IDENTS))
 		success = false;
@@ -2194,7 +2207,6 @@ static Status parse_expr(Parser *p, Expression *e, Token *end) {
 					return false;
 				goto success;
 			}
-		
 			tokr_err(t, "Unrecognized expression.");
 			return false;
 		}
diff --git a/test.toc b/test.toc
index 0885da4..4223656 100644
--- a/test.toc
+++ b/test.toc
@@ -1,33 +1,30 @@
 printf ::= #foreign("printf","libc.so.6") fn(#C &"const char", #C ..) #C int;
 
-tprintf ::= fn(fmt: []char, args: ..) {
-	printf(&fmt[0], args);
-};
-
-sum ::= fn(x: ..) int {
-	total := 0;
-	for a, i := x { 
-		total += a + i - i + 1;
+tprintf_valid ::= fn(fmt :: []char, nargs: int) bool {
+	if fmt[fmt.len-1] != '\0' {
+		return false;
 	}
-	total - x.len
-};
-
-sumc ::= fn(x:: ..) int {
-	total := 0;
-	for a, i := x { 
-		total += a + i - i + 1;
+	count := 0;
+	for x, i := fmt {
+		if x == '%' {
+			if i == fmt.len-1 {
+				count += 1;
+			} elif fmt[i+1] != '%' {
+				count += 1;
+			} else {
+				count -= 1;
+			}
+		}
 	}
-	total - x.len
+	count == nargs
 };
+	
 
-do_printing ::= fn(x::..) {
-	tprintf("%ld\n",sum(x));
-	tprintf("%ld\n",sumc(x));
+tprintf ::= fn(fmt :: []char, args: ..) where tprintf_valid(fmt, args.len) {
+	f := fmt;
+	printf(&f[0], args);
 };
 
 main ::= fn() {
-	do_printing();
-	do_printing(1,2,3);
-	do_printing(4);
-	do_printing(1,10,100,1000,10000);
+	 tprintf("%d %d%%\n\0", 3, 4);
 };
diff --git a/tests/test.sh b/tests/test.sh
index 5b2cb14..31ccb6b 100755
--- a/tests/test.sh
+++ b/tests/test.sh
@@ -8,6 +8,7 @@ foreign
 params
 nms
 varargs
+where
 misc'
 
 STARTPWD=$(pwd)
diff --git a/tests/where.toc b/tests/where.toc
new file mode 100644
index 0000000..4223656
--- /dev/null
+++ b/tests/where.toc
@@ -0,0 +1,30 @@
+printf ::= #foreign("printf","libc.so.6") fn(#C &"const char", #C ..) #C int;
+
+tprintf_valid ::= fn(fmt :: []char, nargs: int) bool {
+	if fmt[fmt.len-1] != '\0' {
+		return false;
+	}
+	count := 0;
+	for x, i := fmt {
+		if x == '%' {
+			if i == fmt.len-1 {
+				count += 1;
+			} elif fmt[i+1] != '%' {
+				count += 1;
+			} else {
+				count -= 1;
+			}
+		}
+	}
+	count == nargs
+};
+	
+
+tprintf ::= fn(fmt :: []char, args: ..) where tprintf_valid(fmt, args.len) {
+	f := fmt;
+	printf(&f[0], args);
+};
+
+main ::= fn() {
+	 tprintf("%d %d%%\n\0", 3, 4);
+};
diff --git a/tests/where_expected b/tests/where_expected
new file mode 100644
index 0000000..622c000
--- /dev/null
+++ b/tests/where_expected
@@ -0,0 +1 @@
+3 4%
diff --git a/tokenizer.c b/tokenizer.c
index 15f7452..17fbde4 100644
--- a/tokenizer.c
+++ b/tokenizer.c
@@ -168,6 +168,10 @@ static Location token_location(File *file, Token *t) {
 	return loc;
 }
 
+static void print_token_location(File *file, Token *t) {
+	print_location(token_location(file, t));
+}
+
 /* for use during tokenization */
 static void tokenization_err_(
 #if ERR_SHOW_SOURCE_LOCATION
diff --git a/types.c b/types.c
index 1af49f1..d04e149 100644
--- a/types.c
+++ b/types.c
@@ -2358,6 +2358,32 @@ static Status types_expr(Typer *tr, Expression *e) {
 					}	
 				}
 			}
+
+			if (fn_copy->condition) {
+				typer_block_enter(tr, &fn_copy->body);
+				/* check where condition */
+				if (!types_expr(tr, fn_copy->condition)) {
+					typer_block_exit(tr);
+					return false;
+				}
+				typer_block_exit(tr);
+				
+				Type *condition_type = &fn_copy->condition->type;
+				if (!type_is_builtin(condition_type, BUILTIN_BOOL)) {
+					char *s = type_to_str(condition_type);
+					err_print(fn_copy->condition->where, "where conditions must be of type bool, but this is of type %s.", s);
+					free(s);
+					return false;
+				}
+				Value val;
+				if (!eval_expr(tr->evalr, fn_copy->condition, &val)) {
+					return false;
+				}
+				if (!val.boolv) {
+					err_print(fn_copy->condition->where, "Function where condition not satisfied. You are probably calling this function incorrectly.");
+					return false;
+				}
+			}
 			
 			ret_type = f->type.fn.types;
 			param_types = ret_type + 1;
diff --git a/types.h b/types.h
index 1734393..2ecb331 100644
--- a/types.h
+++ b/types.h
@@ -317,6 +317,7 @@ typedef enum {
 			  KW_FALSE,
 			  KW_NMS,
 			  KW_TYPEOF,
+			  KW_WHERE,
 			  KW_COUNT
 } Keyword;
 
@@ -331,7 +332,7 @@ static const char *const keywords[KW_COUNT] =
 	 "int", "i8", "i16", "i32", "i64",
 	 "u8", "u16", "u32", "u64", "float", "f32", "f64", "Type",
 	 "Namespace",
-	 "char", "bool", "true", "false", "nms", "typeof"};
+	 "char", "bool", "true", "false", "nms", "typeof", "where"};
 
 typedef enum {
 			  NUM_LITERAL_INT,
@@ -689,6 +690,7 @@ typedef struct FnExpr {
 			U64 instance_id;
 			Type ret_type;	
 			Block body;
+			struct Expression *condition; /* fn(...) ... where ...  */
 		};
 		struct {
 			Type type; /* type of this function */
-- 
cgit v1.2.3