Downloader: Protect against partially downloaded files. (#954)

* Downloader: Protect against partially downloaded files.

* Cleanup

* Add 1 minute timeout.

* Checkstyle
This commit is contained in:
modmuss
2023-09-22 18:55:44 +01:00
committed by GitHub
parent 0b36121357
commit bd09af1783

View File

@@ -62,9 +62,11 @@ import net.fabricmc.loom.util.Checksum;
public final class Download { public final class Download {
private static final String E_TAG = "ETag"; private static final String E_TAG = "ETag";
private static final Logger LOGGER = LoggerFactory.getLogger(Download.class); private static final Logger LOGGER = LoggerFactory.getLogger(Download.class);
private static final Duration TIMEOUT = Duration.ofMinutes(1);
private static final HttpClient HTTP_CLIENT = HttpClient.newBuilder() private static final HttpClient HTTP_CLIENT = HttpClient.newBuilder()
.followRedirects(HttpClient.Redirect.ALWAYS) .followRedirects(HttpClient.Redirect.ALWAYS)
.proxy(ProxySelector.getDefault()) .proxy(ProxySelector.getDefault())
.connectTimeout(TIMEOUT)
.build(); .build();
public static DownloadBuilder create(String url) throws URISyntaxException { public static DownloadBuilder create(String url) throws URISyntaxException {
@@ -93,17 +95,20 @@ public final class Download {
this.downloadAttempt = downloadAttempt; this.downloadAttempt = downloadAttempt;
} }
private HttpRequest getRequest() { private HttpRequest.Builder requestBuilder() {
return HttpRequest.newBuilder(url) return HttpRequest.newBuilder(url)
.timeout(TIMEOUT)
.version(httpVersion) .version(httpVersion)
.GET() .GET();
}
private HttpRequest getRequest() {
return requestBuilder()
.build(); .build();
} }
private HttpRequest getETagRequest(String etag) { private HttpRequest getETagRequest(String etag) {
return HttpRequest.newBuilder(url) return requestBuilder()
.version(httpVersion)
.GET()
.header("If-None-Match", etag) .header("If-None-Match", etag)
.build(); .build();
} }
@@ -190,47 +195,12 @@ public final class Download {
return; return;
} }
if (success) { if (!success) {
try {
Files.deleteIfExists(output);
} catch (IOException e) {
throw error(e, "Failed to delete existing file");
}
final long length = Long.parseLong(response.headers().firstValue("Content-Length").orElse("-1"));
AtomicLong totalBytes = new AtomicLong(0);
try (OutputStream outputStream = Files.newOutputStream(output, StandardOpenOption.CREATE_NEW)) {
copyWithCallback(decodeOutput(response), outputStream, value -> {
if (length < 0) {
return;
}
progressListener.onProgress(totalBytes.addAndGet(value), length);
});
} catch (IOException e) {
throw error(e, "Failed to decode and write download output");
}
if (Files.notExists(output)) {
throw error("No file was downloaded");
}
if (length > 0) {
try {
final long actualLength = Files.size(output);
if (actualLength != length) {
throw error("Unexpected file length of %d bytes, expected %d bytes".formatted(actualLength, length));
}
} catch (IOException e) {
throw error(e);
}
}
} else {
throw statusError("HTTP request returned unsuccessful status (%d)", statusCode); throw statusError("HTTP request returned unsuccessful status (%d)", statusCode);
} }
downloadToPath(output, response);
if (useEtag) { if (useEtag) {
final HttpHeaders headers = response.headers(); final HttpHeaders headers = response.headers();
final String responseETag = headers.firstValue(E_TAG.toLowerCase(Locale.ROOT)).orElse(null); final String responseETag = headers.firstValue(E_TAG.toLowerCase(Locale.ROOT)).orElse(null);
@@ -260,6 +230,58 @@ public final class Download {
} }
} }
private void downloadToPath(Path output, HttpResponse<InputStream> response) throws DownloadException {
// Download the file initially to a .part file
final Path partFile = getPartFile(output);
try {
Files.deleteIfExists(output);
Files.deleteIfExists(partFile);
} catch (IOException e) {
throw error(e, "Failed to delete existing file");
}
final long length = Long.parseLong(response.headers().firstValue("Content-Length").orElse("-1"));
AtomicLong totalBytes = new AtomicLong(0);
try (OutputStream outputStream = Files.newOutputStream(partFile, StandardOpenOption.CREATE_NEW)) {
copyWithCallback(decodeOutput(response), outputStream, value -> {
if (length < 0) {
return;
}
progressListener.onProgress(totalBytes.addAndGet(value), length);
});
} catch (IOException e) {
throw error(e, "Failed to decode and write download output");
}
if (Files.notExists(partFile)) {
throw error("No file was downloaded");
}
if (length > 0) {
try {
final long actualLength = Files.size(partFile);
if (actualLength != length) {
throw error("Unexpected file length of %d bytes, expected %d bytes".formatted(actualLength, length));
}
} catch (IOException e) {
throw error(e);
}
}
try {
// Once the file has been fully read, create a hard link to the destination file.
// And then remove the temporary file, this ensures that the output file only exists in fully populated state.
Files.createLink(output, partFile);
Files.delete(partFile);
} catch (IOException e) {
throw error(e, "Failed to complete download");
}
}
private void copyWithCallback(InputStream is, OutputStream os, IntConsumer consumer) throws IOException { private void copyWithCallback(InputStream is, OutputStream os, IntConsumer consumer) throws IOException {
byte[] buffer = new byte[1024]; byte[] buffer = new byte[1024];
int length; int length;
@@ -389,6 +411,18 @@ public final class Download {
} catch (IOException ignored) { } catch (IOException ignored) {
// ignored // ignored
} }
try {
Files.deleteIfExists(getLockFile(output));
} catch (IOException ignored) {
// ignored
}
try {
Files.deleteIfExists(getPartFile(output));
} catch (IOException ignored) {
// ignored
}
} }
// A faster exists check // A faster exists check
@@ -405,6 +439,10 @@ public final class Download {
return output.resolveSibling(output.getFileName() + ".lock"); return output.resolveSibling(output.getFileName() + ".lock");
} }
private Path getPartFile(Path output) {
return output.resolveSibling(output.getFileName() + ".part");
}
private boolean getAndResetLock(Path output) throws DownloadException { private boolean getAndResetLock(Path output) throws DownloadException {
final Path lock = getLockFile(output); final Path lock = getLockFile(output);
final boolean exists = exists(lock); final boolean exists = exists(lock);