summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--copy.c97
-rw-r--r--parse.c13
-rw-r--r--test.toc10
-rw-r--r--types.c71
4 files changed, 163 insertions, 28 deletions
diff --git a/copy.c b/copy.c
index f36ebfb..279abf6 100644
--- a/copy.c
+++ b/copy.c
@@ -1,6 +1,7 @@
/* these copy functions MUST be used before typing!!!! (except for copy_val) */
static void copy_expr(Allocator *a, Expression *out, Expression *in);
+static void copy_decl(Allocator *a, Declaration *out, Declaration *in);
static void copy_block(Allocator *a, Block *out, Block *in);
static void copy_val(Allocator *allocr, Value *out, Value *in, Type *t) {
@@ -88,6 +89,23 @@ static void copy_type(Allocator *a, Type *out, Type *in) {
}
}
+static void copy_fn_expr(Allocator *a, FnExpr *fout, FnExpr *fin, bool copy_body) {
+ size_t i;
+ fout->params = NULL;
+ size_t nparam_decls = arr_len(fin->params);
+ arr_set_lena(&fout->params, nparam_decls, a);
+ for (i = 0; i < nparam_decls; i++)
+ copy_decl(a, fout->params + i, fin->params + i);
+ size_t nret_decls = arr_len(fin->ret_decls);
+ fout->ret_decls = NULL;
+ arr_set_lena(&fout->ret_decls, nret_decls, a);
+ for (i = 0; i < nret_decls; i++)
+ copy_decl(a, fout->ret_decls + i, fin->ret_decls + i);
+ copy_type(a, &fout->ret_type, &fin->ret_type);
+ if (copy_body)
+ copy_block(a, &fout->body, &fin->body);
+}
+
static void copy_expr(Allocator *a, Expression *out, Expression *in) {
*out = *in;
switch (in->kind) {
@@ -136,7 +154,79 @@ static void copy_expr(Allocator *a, Expression *out, Expression *in) {
}
copy_block(a, &eout->body, &ein->body);
} break;
+ case EXPR_FN:
+ copy_fn_expr(a, &out->fn, &in->fn, true);
+ break;
+ case EXPR_CAST: {
+ CastExpr *cin = &in->cast;
+ CastExpr *cout = &out->cast;
+ copy_type(a, &cout->type, &cin->type);
+ copy_expr(a, cout->expr = allocr_malloc(a, sizeof *cout->expr), cin->expr);
+ } break;
+ case EXPR_NEW: {
+ NewExpr *nin = &in->new;
+ NewExpr *nout = &out->new;
+ copy_type(a, &nout->type, &nin->type);
+ if (nin->n) copy_expr(a, nout->n = allocr_malloc(a, sizeof *nout->n), nin->n);
+ } break;
+ case EXPR_CALL: {
+ CallExpr *cin = &in->call;
+ CallExpr *cout = &out->call;
+ copy_expr(a, cout->fn = allocr_malloc(a, sizeof *cout->fn), cin->fn);
+ size_t nargs = arr_len(cin->arg_exprs);
+ cout->arg_exprs = NULL;
+ arr_set_lena(&cout->arg_exprs, nargs, a);
+ for (size_t i = 0; i < nargs; i++) {
+ copy_expr(a, cout->arg_exprs + i, cin->arg_exprs + i);
+ }
+ } break;
+ case EXPR_BLOCK:
+ copy_block(a, &out->block, &in->block);
+ break;
+ case EXPR_TUPLE: {
+ size_t nexprs = arr_len(in->tuple);
+ out->tuple = NULL;
+ arr_set_lena(&out->tuple, nexprs, a);
+ for (size_t i = 0; i < nexprs; i++)
+ copy_expr(a, out->tuple + i, in->tuple + i);
+ } break;
+ case EXPR_C:
+ copy_expr(a, out->c.code = allocr_malloc(a, sizeof *out->c.code), in->c.code);
+ break;
+ case EXPR_DSIZEOF:
+ copy_expr(a, out->dsizeof.of = allocr_malloc(a, sizeof *out->dsizeof.of), in->dsizeof.of);
+ break;
+ case EXPR_DALIGNOF:
+ copy_expr(a, out->dalignof.of = allocr_malloc(a, sizeof *out->dalignof.of), in->dalignof.of);
+ break;
+ case EXPR_SLICE: {
+ SliceExpr *sin = &in->slice;
+ SliceExpr *sout = &out->slice;
+ copy_expr(a, sout->of = allocr_malloc(a, sizeof *sout->of), sin->of);
+ if (sin->from)
+ copy_expr(a, sout->from = allocr_malloc(a, sizeof *sout->from), sin->from);
+ if (sin->to)
+ copy_expr(a, sout->to = allocr_malloc(a, sizeof *sout->to), sin->to);
+ } break;
+ case EXPR_TYPE:
+ copy_type(a, &out->typeval, &in->typeval);
+ break;
+ case EXPR_VAL:
+ copy_val(a, &out->val, &in->val, &in->type);
+ break;
+ }
+}
+
+static void copy_decl(Allocator *a, Declaration *out, Declaration *in) {
+ *out = *in;
+ if (in->flags & DECL_HAS_EXPR)
+ copy_expr(a, &out->expr, &in->expr);
+ if (in->flags & DECL_FOUND_VAL) {
+ copy_val(a, &out->val, &in->val, &in->type);
}
+ if (in->flags & DECL_ANNOTATES_TYPE)
+ copy_type(a, &out->type, &in->type);
+
}
static void copy_stmt(Allocator *a, Statement *out, Statement *in) {
@@ -151,12 +241,7 @@ static void copy_stmt(Allocator *a, Statement *out, Statement *in) {
copy_expr(a, &out->expr, &in->expr);
break;
case STMT_DECL:
- copy_expr(a, &out->decl.expr, &in->decl.expr);
- if (in->decl.flags & DECL_FOUND_VAL) {
- copy_val(a, &out->decl.val, &in->decl.val, &in->decl.type);
- }
- if (in->decl.flags & DECL_ANNOTATES_TYPE)
- copy_type(a, &out->decl.type, &in->decl.type);
+ copy_decl(a, &out->decl, &in->decl);
break;
}
}
diff --git a/parse.c b/parse.c
index 6f6357a..0d79f8f 100644
--- a/parse.c
+++ b/parse.c
@@ -7,6 +7,12 @@ static bool parse_decl(Parser *p, Declaration *d, DeclEndKind ends_with, uint16_
static bool is_decl(Tokenizer *t);
static inline bool ends_decl(Token *t, DeclEndKind ends_with);
+static bool fn_has_any_const_params(FnExpr *f) {
+ arr_foreach(f->params, Declaration, param)
+ if (param->flags & (DECL_IS_CONST | DECL_SEMI_CONST))
+ return true;
+ return false;
+}
static const char *expr_kind_to_str(ExprKind k) {
switch (k) {
@@ -1952,7 +1958,13 @@ static void fprint_fn_expr(FILE *out, FnExpr *f) {
fprintf(out, ") ");
fprint_type(out, &f->ret_type);
fprintf(out, " ");
+ bool anyc = fn_has_any_const_params(f);
+ bool prev = parse_printing_after_types;
+ if (anyc)
+ parse_printing_after_types = false;
fprint_block(out, &f->body);
+ if (anyc)
+ parse_printing_after_types = prev;
}
static void fprint_args(FILE *out, Argument *args) {
@@ -2251,3 +2263,4 @@ static bool expr_is_definitely_const(Expression *e) {
assert(0);
return false;
}
+
diff --git a/test.toc b/test.toc
index 3906c48..ddc09ab 100644
--- a/test.toc
+++ b/test.toc
@@ -8,11 +8,11 @@ puti @= fn(x: int) {
// };
-stuff @= fn(t @ Type) int {
- 4327834 as t as int
-};
-
main @= fn() {
- puti(stuff(int));
+ puti(f(17));
};
+
+f @= fn(k @ int) int {
+k
+}; \ No newline at end of file
diff --git a/types.c b/types.c
index 8b5750c..300438f 100644
--- a/types.c
+++ b/types.c
@@ -185,6 +185,14 @@ static bool type_of_fn(Typer *tr, FnExpr *f, Location where, Type *t) {
t->kind = TYPE_FN;
t->fn.types = NULL;
t->fn.constness = NULL; /* OPTIM: constant doesn't need to be a dynamic array */
+ FnExpr *newf = NULL;
+ if (fn_has_any_const_params(f)) {
+ /* OPTIM don't copy so much */
+ newf = typer_malloc(tr, sizeof *newf);
+ copy_fn_expr(tr->allocr, newf, f, false);
+ f = newf;
+ }
+
bool has_constant_params = false;
Type *ret_type = typer_arr_add(tr, &t->fn.types);
if (f->ret_decls && f->ret_type.kind == TYPE_VOID /* haven't found return type yet */) {
@@ -264,7 +272,7 @@ static bool type_of_fn(Typer *tr, FnExpr *f, Location where, Type *t) {
arr_foreach(f->ret_decls, Declaration, decl) {
if (!types_decl(tr, decl)) return false;
}
- return true;
+ return true;
}
static bool type_of_ident(Typer *tr, Location where, Identifier i, Type *t) {
@@ -589,7 +597,9 @@ static bool arg_is_const(Expression *arg, Constness constness) {
}
-static bool types_fn(Typer *tr, FnExpr *f, Type *t, Location where) {
+/* pass NULL for instance if this isn't an instance */
+static bool types_fn(Typer *tr, FnExpr *f, Type *t, Location where,
+ Instance *instance) {
FnExpr *prev_fn = tr->fn;
bool success = true;
{
@@ -598,10 +608,34 @@ static bool types_fn(Typer *tr, FnExpr *f, Type *t, Location where) {
}
assert(t->kind == TYPE_FN);
-
- /* don't type function body yet; we need to do that for every instance */
- if (t->fn.constness)
- return true;
+
+
+ if (instance) {
+ copy_fn_expr(tr->allocr, &instance->fn, f, true);
+ f = &instance->fn;
+ Value *compile_time_args = instance->val.tuple;
+ U64 which_are_const = compile_time_args[0].u64;
+ compile_time_args++;
+ int compile_time_arg_idx = 0;
+ int semi_const_arg_idx = 0;
+ arr_foreach(f->params, Declaration, param) {
+ if (param->flags & DECL_IS_CONST) {
+ param->val = compile_time_args[compile_time_arg_idx];
+ param->flags |= DECL_FOUND_VAL;
+ compile_time_arg_idx++;
+ } else if (param->flags & DECL_SEMI_CONST) {
+ if (which_are_const & (((U64)1) << semi_const_arg_idx)) {
+ param->val = compile_time_args[compile_time_arg_idx];
+ param->flags |= DECL_FOUND_VAL | DECL_IS_CONST; /* pretend it's constant */
+ compile_time_arg_idx++;
+ }
+ semi_const_arg_idx++;
+ }
+ }
+ } else {
+ if (t->fn.constness)
+ return true; /* don't type function body yet; we need to do that for every instance */
+ }
tr->fn = f;
if (!fn_enter(f, SCOPE_CHECK_REDECL)) {
@@ -665,12 +699,17 @@ static bool types_expr(Typer *tr, Expression *e) {
t->kind = TYPE_UNKNOWN; /* default to unknown type (in the case of an error) */
e->flags |= EXPR_FOUND_TYPE; /* even if failed, pretend we found the type */
switch (e->kind) {
- case EXPR_FN:
+ case EXPR_FN: {
if (!type_of_fn(tr, &e->fn, e->where, &e->type))
- return false;
- if (!types_fn(tr, &e->fn, &e->type, e->where))
return false;
- break;
+ if (fn_has_any_const_params(&e->fn)) {
+ HashTable z = {0};
+ e->fn.instances = z;
+ } else {
+ if (!types_fn(tr, &e->fn, &e->type, e->where, NULL))
+ return false;
+ }
+ } break;
case EXPR_LITERAL_INT:
t->kind = TYPE_BUILTIN;
t->builtin = BUILTIN_I64;
@@ -946,12 +985,8 @@ static bool types_expr(Typer *tr, Expression *e) {
CallExpr *c = &e->call;
c->instance = NULL;
Expression *f = c->fn;
- if (f->kind == EXPR_IDENT) {
- /* allow calling a function before declaring it */
- if (!type_of_ident(tr, f->where, f->ident, &f->type)) return false;
- } else {
- if (!types_expr(tr, f)) return false;
- }
+ FnExpr *fn_decl = NULL;
+ if (!types_expr(tr, f)) return false;
arr_foreach(c->args, Argument, arg) {
if (!types_expr(tr, &arg->val))
return false;
@@ -971,7 +1006,6 @@ static bool types_expr(Typer *tr, Expression *e) {
size_t nparams = arr_len(f->type.fn.types) - 1;
size_t nargs = arr_len(c->args);
bool ret = true;
- FnExpr *fn_decl = NULL;
Expression *new_args = NULL;
arr_set_lena(&new_args, nparams, tr->allocr);
bool *params_set = nparams ? typer_calloc(tr, nparams, sizeof *params_set) : NULL;
@@ -1132,6 +1166,9 @@ static bool types_expr(Typer *tr, Expression *e) {
c->instance = instance_table_adda(tr->allocr, &fn->instances, table_index, &table_index_type, &instance_already_exists);
c->instance->c.id = fn->instances.n; /* let's help cgen out and assign an ID to this */
arr_clear(&table_index_type.tuple);
+ /* type this instance */
+ if (!types_fn(tr, fn, &f->type, e->where, c->instance))
+ return false;
}
*t = *ret_type;
c->arg_exprs = new_args;