diff --git a/src/socket.c b/src/socket.c index 7203db9797d8b99f7350eaf3aca8950da1435934..aee593cc92701bdf97f06ddd33e62573cb6b57fc 100644 --- a/src/socket.c +++ b/src/socket.c @@ -663,88 +663,59 @@ mrb_win32_basicsocket_close(mrb_state *mrb, mrb_value self) return mrb_nil_value(); } +#define E_EOF_ERROR (mrb_class_get(mrb, "EOFError")) static mrb_value -mrb_win32_basicsocket_read(mrb_state *mrb, mrb_value self) +mrb_win32_basicsocket_sysread(mrb_state *mrb, mrb_value self) { -#define BUF_LEN 4096 - char buf[BUF_LEN]; - int sd, bytes_read; - mrb_int read_len; - mrb_value str = mrb_str_new(mrb, NULL, 0); - mrb_value len_arg = mrb_nil_value(); - - /* For compatibility with IO::read, we accept an object instead of just - * an mrb_int. Then we raise an exception if it isn't an mrb_int or nil. - */ - mrb_get_args(mrb, "|o", &len_arg); - if (mrb_fixnum_p(len_arg)) { - read_len = mrb_fixnum(len_arg); - if (read_len < 0) { - mrb_str_cat_cstr(mrb, str, "negative length: "); - mrb_str_append(mrb, str, len_arg); - mrb_str_cat_cstr(mrb, str, " given"); - mrb_raise(mrb, E_ARGUMENT_ERROR, mrb_string_value_cstr(mrb, &str)); - return mrb_nil_value(); - } else if (read_len == 0) { - return str; - } - } else if (mrb_nil_p(len_arg)) { - read_len = 0; - } else { - mrb_str_cat_cstr(mrb, str, "can't convert "); - mrb_str_append(mrb, str, len_arg); - mrb_str_cat_cstr(mrb, str, " into Integer"); - mrb_raise(mrb, E_TYPE_ERROR, mrb_string_value_cstr(mrb, &str)); + int sd, ret; + mrb_value buf = mrb_nil_value(); + mrb_int maxlen; + + mrb_get_args(mrb, "i|S", &maxlen, &buf); + if (maxlen < 0) { return mrb_nil_value(); } + if (mrb_nil_p(buf)) { + buf = mrb_str_new(mrb, NULL, maxlen); + } + if (RSTRING_LEN(buf) != maxlen) { + buf = mrb_str_resize(mrb, buf, maxlen); + } + sd = socket_fd(mrb, self); - bytes_read = 0; - - /* Behavior of positive integer argument: read until length bytes have been - * read, or until the socket has closed. */ - if (read_len > 0) { - int n = 0; - int max_read_len = BUF_LEN; - int keep_reading = 1; - while (keep_reading) { - if (max_read_len > (read_len - bytes_read)) - max_read_len = read_len - bytes_read; - n = recv(sd, buf, max_read_len, 0); - - if (n == SOCKET_ERROR) { - mrb_sys_fail(mrb, "recv"); - } else if (n == 0) { - keep_reading = 0; + ret = recv(sd, RSTRING_PTR(buf), maxlen, 0); + + switch (ret) { + case 0: /* EOF */ + if (maxlen == 0) { + buf = mrb_str_new_cstr(mrb, ""); } else { - mrb_str_cat(mrb, str, buf, n); - bytes_read += n; - if (bytes_read >= read_len) - keep_reading = 0; + mrb_raise(mrb, E_EOF_ERROR, "sysread failed: End of File"); } - } - } - /* Behavior of nil/default argument: read until socket has closed */ - else { - int keep_reading = 1; - while (keep_reading) { - bytes_read = recv(sd, buf, BUF_LEN, 0); - if (bytes_read == SOCKET_ERROR) { - mrb_sys_fail(mrb, "recv"); - } else if (bytes_read == 0) { - keep_reading = 0; - } else { - mrb_str_cat(mrb, str, buf, bytes_read); + break; + case SOCKET_ERROR: /* Error */ + mrb_sys_fail(mrb, "recv"); + break; + default: + if (RSTRING_LEN(buf) != ret) { + buf = mrb_str_resize(mrb, buf, ret); } - } + break; } - return str; -#undef BUF_LEN + return buf; +} + +static mrb_value +mrb_win32_basicsocket_sysseek(mrb_state *mrb, mrb_value self) +{ + mrb_raise(mrb, E_NOTIMP_ERROR, "sysseek not implemented for windows sockets"); + return mrb_nil_value(); } static mrb_value -mrb_win32_basicsocket_write(mrb_state *mrb, mrb_value self) +mrb_win32_basicsocket_syswrite(mrb_state *mrb, mrb_value self) { int n; SOCKET sd; @@ -842,8 +813,9 @@ mrb_mruby_socket_gem_init(mrb_state* mrb) /* Windows IO Method Overrides on BasicSocket */ #ifdef _WIN32 mrb_define_method(mrb, bsock, "close", mrb_win32_basicsocket_close, MRB_ARGS_NONE()); - mrb_define_method(mrb, bsock, "read", mrb_win32_basicsocket_read, MRB_ARGS_OPT(1)); - mrb_define_method(mrb, bsock, "write", mrb_win32_basicsocket_write, MRB_ARGS_REQ(1)); + mrb_define_method(mrb, bsock, "sysread", mrb_win32_basicsocket_sysread, MRB_ARGS_REQ(1)|MRB_ARGS_OPT(1)); + mrb_define_method(mrb, bsock, "sysseek", mrb_win32_basicsocket_sysseek, MRB_ARGS_REQ(1)); + mrb_define_method(mrb, bsock, "syswrite", mrb_win32_basicsocket_syswrite, MRB_ARGS_REQ(1)); #endif constants = mrb_define_module_under(mrb, sock, "Constants");